Download pytorch_wavelets:
git clone https://github.com/fbcotter/pytorch_wavelets
Then install it:
cd pytorch_wavelets
pip install .
Output:
Successfully built pytorch-wavelets
Installing collected packages: pytorch-wavelets
Successfully installed pytorch-wavelets-1.2.2
List the wavelets available to you:
>>> import pywt
>>> pywt.wavelist('haar')
['haar']
>>> pywt.wavelist('db')
['db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8', 'db9', 'db10', 'db11', 'db12', 'db13', 'db14', 'db15', 'db16', 'db17', 'db18', 'db19', 'db20', 'db21', 'db22', 'db23', 'db24', 'db25', 'db26', 'db27', 'db28', 'db29', 'db30', 'db31', 'db32', 'db33', 'db34', 'db35', 'db36', 'db37', 'db38']
For details, see:
https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html
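Only discrete wavelets make sense for a DWT, since pytorch_wavelets passes a string name to pywt.Wavelet. A minimal sketch, assuming PyWavelets is installed, that separates discrete from continuous names and prints the analysis filter length of a few candidates (longer filters mean more padding and slightly larger outputs per level):

```python
import pywt

# Wavelets usable for a discrete wavelet transform (haar, db*, sym*, coif*, bior*, ...)
discrete = pywt.wavelist(kind='discrete')
# Continuous-only wavelets (morl, mexh, gaus*, ...) belong to pywt.ContinuousWavelet,
# not to pywt.Wavelet, so they are not candidates for DWTForward
continuous = pywt.wavelist(kind='continuous')

for name in ['haar', 'db2', 'sym4']:
    w = pywt.Wavelet(name)
    # dec_len is the length of the decomposition (analysis) filters
    print(name, w.dec_len)
```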
The pytorch_wavelets source (https://github.com/fbcotter/pytorch_wavelets/blob/master/pytorch_wavelets/dwt/transform2d.py) shows that the wave parameter is handled through pywt.Wavelet:
class DWTForward(nn.Module):
    """ Performs a 2d DWT Forward decomposition of an image

    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to
            pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class,
            or can be a two tuple of array-like objects for the analysis low and
            high pass filters.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme
        separable (bool): whether to do the filtering separably or not (the
            naive implementation can be faster on a gpu).
        """
    def __init__(self, J=1, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        ...
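The excerpt shows that a string is turned into a pywt.Wavelet and only its dec_lo / dec_hi analysis filters are used. The sketch below, assuming PyWavelets is installed, extracts that same filter pair and checks the properties of an orthonormal wavelet filter bank. Per the docstring, such a pair could also be passed directly as a two-tuple, e.g. DWTForward(J=1, wave=(h0, h1)); that call is not run here.

```python
import math
import pywt

w = pywt.Wavelet('haar')
# The same attributes DWTForward reads when wave is a string or a pywt.Wavelet
h0, h1 = w.dec_lo, w.dec_hi
print(h0)  # low-pass analysis filter
print(h1)  # high-pass analysis filter

# Orthonormal filters: low-pass sums to sqrt(2), high-pass sums to 0,
# and each filter has unit energy
for name in ['haar', 'db2']:
    wav = pywt.Wavelet(name)
    lo, hi = wav.dec_lo, wav.dec_hi
    print(name, sum(lo), math.sqrt(2), sum(hi), sum(c * c for c in lo))
```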
Example:
#coding:utf-8
import torch
import torchvision
from PIL import Image
from torchvision import transforms as trans


def test3():
    from pytorch_wavelets import DWTForward, DWTInverse  # (or import DWT, IDWT)
    # J is the number of decomposition levels; wave selects the wavelet
    xfm = DWTForward(J=1, mode='zero', wave='haar')  # Accepts all wave types available to PyWavelets
    ifm = DWTInverse(mode='zero', wave='haar')  # inverse transform (not used below)
    # convert('RGB') guarantees 3 channels even for grayscale or RGBA files
    img = Image.open('./1.jpg').convert('RGB')
    transform = trans.Compose([
        trans.ToTensor()
    ])
    img = transform(img).unsqueeze(0)  # shape (1, 3, H, W)
    Yl, Yh = xfm(img)
    print(Yl.shape)
    print(len(Yh))
    for i in range(len(Yh)):
        print(Yh[i].shape)
        # Yh[i] has shape (N, C, 3, H, W): size(3) is the height, size(4) the width
        C, H, W = Yh[i].size(1), Yh[i].size(3), Yh[i].size(4)
        last = (i == len(Yh) - 1)
        if last:
            # The last level also shows the low-pass band Yl in the first slot
            h = torch.zeros([4, C, H, W]).float()
            h[0] = Yl[0]
        else:
            h = torch.zeros([3, C, H, W]).float()
        offset = 1 if last else 0
        for j in range(3):
            h[j + offset] = Yh[i][0, :, j]  # j-th detail band of the first image
        nrow = 2 if last else 3  # 2 or 3 images per row
        img_grid = torchvision.utils.make_grid(h, nrow)
        torchvision.utils.save_image(img_grid, 'img_grid_{}.jpg'.format(i))


if __name__ == '__main__':
    test3()
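save_image writes pixel values clamped to [0, 1]. With orthonormal Haar filters the LL band of a [0, 1] image can reach 2, while the detail bands are signed and mostly small, so the saved grids lose part of their structure. One option is to rescale each band to [0, 1] before saving; torchvision.utils.make_grid offers this through normalize=True and scale_each=True. The NumPy sketch below shows the same per-band min-max scaling on plain arrays; scale_band is a hypothetical helper, not part of either library.

```python
import numpy as np


def scale_band(band, eps=1e-12):
    """Min-max rescale one sub-band to [0, 1] for display.

    A constant band (max == min) maps to all zeros.
    """
    band = np.asarray(band, dtype=np.float64)
    lo, hi = band.min(), band.max()
    return (band - lo) / max(hi - lo, eps)


# A signed detail band and an LL band that exceeds 1
detail = np.array([[-0.2, 0.0], [0.1, 0.3]])
ll = np.array([[0.5, 2.0], [1.0, 1.5]])
print(scale_band(detail))
print(scale_band(ll))
```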
Output:
(deeplearning) bogon:learning user$ python delete.py
torch.Size([1, 3, 56, 56])
1
torch.Size([1, 3, 3, 56, 56])
Result:

This result is nearly identical to that of the PyTorch Haar wavelet transform used in MWCNN.
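Yl has shape (N, C_in, H_in′, W_in′), where H_in′ × W_in′ is the size of the LL band produced by the last decomposition level. For a 112×112 input, one level gives 56×56, two levels give 28×28, and three levels give 14×14.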
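The exact halving 112 → 56 → 28 → 14 holds for Haar because its filters have length 2. A small sketch of the general rule, assuming pytorch_wavelets sizes each level the way PyWavelets' pywt.dwt_coeff_len does: floor((n + L - 1) / 2) for the padding modes and ceil(n / 2) for 'periodization', where L is the filter length. The helper names dwt_len and dwt_shapes are hypothetical.

```python
def dwt_len(n, filt_len, mode='zero'):
    """Length of one axis after a single DWT level (PyWavelets' rule)."""
    if mode == 'periodization':
        return (n + 1) // 2
    return (n + filt_len - 1) // 2


def dwt_shapes(h, w, J, filt_len=2, mode='zero'):
    """Spatial sizes of Yh[0..J-1] and of Yl after J levels."""
    yh = []
    for _ in range(J):
        h, w = dwt_len(h, filt_len, mode), dwt_len(w, filt_len, mode)
        yh.append((h, w))  # the three detail bands of this level share this size
    return h, w, yh  # the final (h, w) is the size of Yl


print(dwt_shapes(112, 112, J=3))              # (14, 14, [(56, 56), (28, 28), (14, 14)])
print(dwt_shapes(112, 112, J=3, filt_len=4))  # (16, 16, [(57, 57), (30, 30), (16, 16)])
```

With a length-4 filter such as db2, sizes no longer halve exactly, which matters when Yl must line up with feature maps elsewhere in a network.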
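Yh is a list of tensors of shape (N, C_in, 3, H_in″, W_in″) whose length equals the number of decomposition levels J: Yh[0] holds the HL, LH and HH bands of the first level, Yh[1] those of the second level, and Yh[2] those of the third level.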
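The 3 in (N, C_in, 3, H_in″, W_in″) indexes these three high-frequency bands (the pytorch_wavelets documentation lists them in the order LH, HL, HH).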
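To make the Yl / Yh layout concrete, here is a NumPy sketch of an orthonormal 2D Haar decomposition of a single-channel array whose sides are divisible by 2**J. It is not pytorch_wavelets' implementation: the function names are hypothetical, and the order and signs of the three detail bands follow this sketch's own convention. Each level splits the current LL band into a half-size LL plus three detail bands, and the list is built finest level first, just as Yh[0] is the first level.

```python
import numpy as np


def haar_dwt2(x):
    """One level of an orthonormal 2D Haar DWT on an (H, W) array, H and W even.

    Returns (ll, bands) where bands has shape (3, H/2, W/2).
    """
    a = x[0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # low-pass in both directions
    d1 = (a + b - c - d) / 2  # top minus bottom: horizontal edges
    d2 = (a - b + c - d) / 2  # left minus right: vertical edges
    d3 = (a - b - c + d) / 2  # diagonal detail
    return ll, np.stack([d1, d2, d3])


def haar_idwt2(ll, bands):
    """Exact inverse of haar_dwt2."""
    d1, d2, d3 = bands
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=ll.dtype)
    x[0::2, 0::2] = (ll + d1 + d2 + d3) / 2
    x[0::2, 1::2] = (ll + d1 - d2 - d3) / 2
    x[1::2, 0::2] = (ll - d1 + d2 - d3) / 2
    x[1::2, 1::2] = (ll - d1 - d2 + d3) / 2
    return x


def haar_wavedec2(x, J):
    """J-level decomposition: returns (yl, yh) with yh[0] the finest level."""
    yh = []
    ll = np.asarray(x, dtype=np.float64)
    for _ in range(J):
        ll, bands = haar_dwt2(ll)  # only the LL band is decomposed further
        yh.append(bands)
    return ll, yh


def haar_waverec2(yl, yh):
    """Invert haar_wavedec2, starting from the coarsest level."""
    x = yl
    for bands in reversed(yh):
        x = haar_idwt2(x, bands)
    return x


img = np.random.default_rng(0).random((112, 112))
yl, yh = haar_wavedec2(img, J=3)
print(yl.shape)               # (14, 14)
print([b.shape for b in yh])  # [(3, 56, 56), (3, 28, 28), (3, 14, 14)]
print(np.allclose(haar_waverec2(yl, yh), img))  # True
```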
For details, see https://pytorch-wavelets.readthedocs.io/en/latest/dwt.html
For a three-level decomposition, set J=3, i.e. xfm = DWTForward(J=3, mode='zero', wave='haar'):
Output:
(deeplearning) bogon:learning user$ python delete.py
torch.Size([1, 3, 14, 14])
3
torch.Size([1, 3, 3, 56, 56])
torch.Size([1, 3, 3, 28, 28])
torch.Size([1, 3, 3, 14, 14])
Result:



With J=2 (two levels), the output is:
(deeplearning) bogon:learning user$ python delete.py
torch.Size([1, 3, 28, 28])
2
torch.Size([1, 3, 3, 56, 56])
torch.Size([1, 3, 3, 28, 28])
Result:


