PyTorch實例:基於自編碼器的圖形去噪


  去噪自編碼器模擬人類視覺機制能夠自動忍受圖像的噪聲來識別圖片。自編碼器的目標是要學習一個近似的恆等函數,使得輸出近似等於輸入。去噪自編碼器采用隨機的部分帶噪輸入來解決恆等函數問題,自編碼器能夠獲得輸入的良好表征,該表征使得自編碼器能進行去噪或恢復。

  下面是代碼:

#加載庫和配置參數
#去噪自編碼器
import torch
import torch.nn as nn
import torch.utils as utils
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
#配置參數
torch.manual_seed(1)
n_epoch=200
batch_size=100
learning_rate=0.002

#下載圖片庫訓練集
mnist_train=dset.MNIST("./",train=True,transform=transforms.ToTensor(),target_transform=None,download=True)
train_loader=torch.utils.data.DataLoader(dataset=mnist_train,batch_size=batch_size,shuffle=True)

#Encoder和Decoder模型設置
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.layer1=nn.Sequential(
            nn.Conv2d(1,32,3,padding=1),#batch*32*28*28
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32,32,3,padding=1),#batch*32*28*28
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32,64,3,padding=1),#batch*64*28*28
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64,64,3,padding=1),#batch*64*28*28
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2,2)#batch*64*14*14
        )
        self.layer2=nn.Sequential(
            nn.Conv2d(64,128,3,padding=1),#batch*128*14*14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128,128,3,padding=1),#batch*128*14*14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2,2),
            nn.Conv2d(128,256,3,padding=1),#batch*256*14*14
            nn.ReLU()
        )
    def forward(self,x):
        out=self.layer1(x)
        out=self.layer2(out)
        out=out.view(batch_size,-1)
        return out

encoder=Encoder().cuda()

#decoder設置
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.layer1=nn.Sequential(
            nn.ConvTranspose2d(256,128,3,2,1,1),#batch*128*14*14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128,128,3,1,1),#batch*128*14*14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128,64,3,1,1),#batch*64*14*14
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64,64,3,1,1),#batch*64*14*14
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )
        self.layer2=nn.Sequential(
            nn.ConvTranspose2d(64,32,3,1,1),#batch*32*14*14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32,32,3,1,1),#batch*32*14*14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32,1,3,2,1,1),#batch*1*28*28
            nn.ReLU()
        )
    def forward(self,x):
        out=x.view(batch_size,256,7,7)
        out=self.layer1(out)
        out=self.layer2(out)
        return out
decoder=Decoder().cuda()

###Loss 函數和優化器
parameters=list(encoder.parameters())+list(decoder.parameters())
loss_func=nn.MSELoss()
optimizer=torch.optim.Adam(parameters,lr=learning_rate)

###自編碼器訓練
#添加噪聲
noise=torch.rand(batch_size,1,28,28)
for I in range(n_epoch):
    for image,label in train_loader:
        image_n=torch.mul(image+0.25,0.1*noise)
        image=Variable(image).cuda()
        image_n=Variable(image_n).cuda()
        optimizer.zero_grad()
        output=encoder(image_n)
        output=decoder(output)
        loss=loss_func(output,image)
        loss.backward()
        optimizer.step()
        break
    print('epoch[{}/{}],loss:{:.4f}'.format(I+1,n_epoch,loss.item()))

####帶噪圖片和去噪圖片對比
img=image[0].cpu()
input_img=image_n[0].cpu()
output_img=output[0].cpu()
origin=img.data.numpy()
inp=input_img.data.numpy()
out=output_img.data.numpy()
plt.figure('denoising autoencoder')
plt.subplot(131)
plt.imshow(origin[0],cmap='gray')
plt.subplot(132)
plt.imshow(inp[0],cmap='gray')
plt.subplot(133)
plt.imshow(out[0],cmap='gray')
plt.show()
print(label[0])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM