基於PyTorch實現圖像去模糊-學習
任務描述
- 相機的抖動、快速運動的物體都會導致拍攝出模糊的圖像,景深變化也會使圖像進一步模糊。
- 對於傳統方法來說,要想估計出每個像素點對應的 “blur kernel” 幾乎是不可行的。因此,傳統方法常常需要對模糊源作出假設,將 “blur kernel” 參數化。顯然,這類方法不足以解決實際中各種復雜因素引起的圖像模糊。
- 卷積神經網絡能夠從圖像中提取出復雜的特征,從而使得模型能夠適應各種場景。
- 本教程以 CVPR2017 的 《Deep Multi-scale Convolutional Neural Network for Dynamic Scene Deblurring》 為例,來完成圖像去模糊的任務。
方法概述
- 利用pytorch深度學習工具實現一個端到端的圖像去模糊模型,通過參數設置、加載數據、構建模型、訓練模型和測試用例依次實現一個圖像去模糊工具,在訓練和預處理過程中通過可視化監督訓練過程。
- 模型采用了殘差形式的CNN,輸入和輸出都采用高斯金字塔(Gaussian pyramid)的形式。
- 整個網絡結構由三個相似的CNN構成,分別對應輸入金字塔中的每一層。網絡最前面是分辨率最低的子網絡(coarest level network),在這個子網絡最后,是“upconvolution layer”,將重建的低分辨率圖像放大為高分辨率圖像,然后和高一層的子網絡的輸入連接在一起,作為上層網絡的輸入。
%config Completer.use_jedi = False
#!pip install pytorch_msssim -i https://pypi.tuna.tsinghua.edu.cn/simple
# !jupyter nbextension enable --py widgetsnbextension
import torch
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tensorboardX import SummaryWriter
from torchsummary import summary
from torch.optim import lr_scheduler
from torch.utils import data
from torchvision import transforms
from tqdm.notebook import tqdm
import pytorch_msssim # 用於計算指標 ssim 和 mssim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
參數設置
class Config():
def __init__(self,name="Configs"):
# train set
self.data_dir = 'datasets/train' # 訓練集目錄
self.patch_size = 256 # 輸入模型的patch的尺寸
self.batch_size= 2 #16 # 訓練時每個batch中的樣本個數
self.n_threads = 1 # 用於加載數據的線程數
# test set
self.test_data_dir = 'datasets/test' # 測試集目錄
self.test_batch_size=1 # 測試時的 batch_size
# model
self.multi = True # 模型采用多尺度方法True
self.skip = True # 模型采用滑動連接方法
self.n_resblocks = 3 #9 # resblock的個數
self.n_feats = 8 #64 #feature map的個數
# optimization
self.lr = 1e-4 # 初始學習率
self.epochs =5 #800 # 訓練epoch的數目
self.lr_step_size = 600 #采用步進學習率策略所用的 step_size
self.lr_gamma = 0.1 #每 lr_step_size后,學習率變成 lr * lr_gamma
# global
self.name = name #配置的名稱
self.save_dir = 'temp/result' # 保存訓練過程中所產生數據的目錄
self.save_cp_dir = 'temp/models' # 保存 checkpoint的目錄
self.imgs_dir = 'datasets/pictures' # 此 notebook所需的圖片目錄
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
if not os.path.exists(self.save_cp_dir):
os.makedirs(self.save_cp_dir)
# if not os.path.exists(self.data_dir):
# os.makedirs(self.data_dir)
# if not os.path.exists(self.test_data_dir):
# os.makedirs(self.test_data_dir)
args = Config(name="image-deblurring")
數據准備
- 數據集展示
- 數據增強
- 構造 dataset類
- 數據加載 dataloader
數據集展示
sample_idx = 1 # 樣本編號
blur_path = os.path.join(args.imgs_dir,f"blur/test{sample_idx}.png") # 模糊圖片
sharp_path = os.path.join(args.imgs_dir,f"sharp/test{sample_idx}.png") # 去模糊圖片
blur_img = plt.imread(blur_path)
sharp_img = plt.imread(sharp_path)
plt.figure(figsize=(10,4))
plt.subplot(121)
plt.imshow(blur_img)
plt.subplot(122)
plt.imshow(sharp_img)
plt.show()
數據增強
為了防止過擬合,需要對數據集進行數據增強,增強方式如下所示,對每一個輸入圖像,都將其進行隨機角度旋轉,旋轉的角度在 [0, 90, 180, 270] 中隨機選取。除此之外,考慮到圖像質量下降,對 HSV 顏色空間的飽和度乘以 0.8 到 1.2 內的隨機數
def augment(img_input, img_target):
degree = random.choice([0,90,180,270])
img_input = transforms.functional.rotate(img_input,degree)
img_target = transforms.functional.rotate(img_target,degree)
# color augmentation
img_input = transforms.functional.adjust_gamma(img_input,1)
img_target = transforms.functional.adjust_gamma(img_target,1)
sat_factor = 1 + (0.2 - 0.4* np.random.rand())
img_input = transforms.functional.adjust_saturation(img_input,sat_factor)
img_target = transforms.functional.adjust_saturation(img_target,sat_factor)
return img_input,img_target
img_input = Image.open(blur_path)
img_target = Image.open(sharp_path)
img_aug_input,img_aug_target = augment(img_input,img_target)
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(img_aug_input)
plt.subplot(122)
plt.imshow(img_aug_target)
plt.show()
構造 dataset類
對每一個輸入圖像,對齊進行隨機裁剪,得到patch_size大小的輸入
def getPatch(img_input,img_target,patch_size):
w,h = img_input.size
p = patch_size
x = random.randrange(0,w-p +1)
y = random.randrange(0,h -p +1)
img_input = img_input.crop((x,y,x+p,y+p))
img_target = img_target.crop((x,y,x+p,y+p))
return img_input,img_target
class ImgMission(data.Dataset):
def __init__(self,data_dir, patch_size=256, is_train= False, multi=True):
super(ImgMission,self).__init__()
self.is_train = is_train #是否是訓練集
self.patch_size = patch_size # 訓練時 patch的尺寸
self.multi = multi # 是否采用多尺度因子,默認采用
self.sharp_file_paths = []
sub_folders = os.listdir(data_dir)
print(sub_folders)
for folder_name in sub_folders:
sharp_sub_folder = os.path.join(data_dir,folder_name,'sharp')
sharp_file_names = os.listdir(sharp_sub_folder)
# print(sharp_file_names)
for file_name in sharp_file_names:
sharp_file_path = os.path.join(sharp_sub_folder,file_name)
# print(sharp_file_path)
self.sharp_file_paths.append(sharp_file_path)
self.n_samples = len(self.sharp_file_paths)
def get_img_pair(self,idx):
sharp_file_path = self.sharp_file_paths[idx]
blur_file_path = sharp_file_path.replace("sharp","blur")
# print(blur_file_path)
img_input = Image.open(blur_file_path).convert('RGB')
img_target = Image.open(sharp_file_path).convert('RGB')
return img_input,img_target
def __getitem__(self,idx):
img_input,img_target = self.get_img_pair(idx)
if self.is_train:
img_input,img_target = getPatch(img_input,img_target, self.patch_size)
img_input,img_target= augment(img_input,img_target)
# 轉換為 tensor類型
input_b1 = transforms.ToTensor()(img_input)
target_s1 = transforms.ToTensor()(img_target)
H = input_b1.size()[1]
W= input_b1.size()[2]
if self.multi:
input_b1 = transforms.ToPILImage()(input_b1)
target_s1 = transforms.ToPILImage()(target_s1)
input_b2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(input_b1))
input_b3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(input_b1))
# 只對訓練集進行數據增強
if self.is_train:
target_s2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(target_s1))
target_s3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(target_s1))
else:
target_s2 = []
target_s3 = []
input_b1 = transforms.ToTensor()(input_b1)
target_s1 = transforms.ToTensor()(target_s1)
return {
'input_b1': input_b1, # 參照下文的網絡結構,輸入圖像的尺度 1
'input_b2': input_b2, # 輸入圖像的尺度 2
'input_b3': input_b3, # 輸入圖像的尺度 3
'target_s1': target_s1, # 目標圖像的尺度 1
'target_s2': target_s2, # 目標圖像的尺度 2
'target_s3': target_s3 # 目標圖像的尺度 3
}
else:
return {'input_b1': input_b1, 'target_s1': target_s1}
def __len__(self):
return self.n_samples
數據加載 dataloader
def get_dataset(data_dir,patch_size=None,
batch_size=1, n_threads=1,
is_train=False,multi=False):
# Dataset實例化
# print(data_dir)
# print(patch_size)
# print(is_train)
# print(multi)
dataset = ImgMission(data_dir,patch_size=patch_size,
is_train=is_train,multi=multi)
# print(dataset)
# 利用封裝好的 dataloader 接口定義訓練過程的迭代器
# 參數num_workers表示進程個數,在jupyter下改為0
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,
drop_last=True, shuffle=is_train,
num_workers = 0)
return dataloader
- 將訓練時的dataloader實例化
data_loader = get_dataset(args.data_dir,
patch_size=args.patch_size,
batch_size= args.batch_size,
n_threads= args.n_threads,
is_train=True,
multi = args.multi
)
['GOPR0372_07_00', 'GOPR0372_07_01', 'GOPR0374_11_00', 'GOPR0374_11_01', 'GOPR0374_11_02', 'GOPR0374_11_03', 'GOPR0378_13_00', 'GOPR0379_11_00', 'GOPR0380_11_00', 'GOPR0384_11_01', 'GOPR0384_11_02', 'GOPR0384_11_03', 'GOPR0384_11_04', 'GOPR0385_11_00', 'GOPR0386_11_00', 'GOPR0477_11_00', 'GOPR0857_11_00', 'GOPR0868_11_01', 'GOPR0868_11_02', 'GOPR0871_11_01', 'GOPR0881_11_00', 'GOPR0884_11_00']
模型構建
- 模型介紹
- 模型定義
- 實例化模型
- 損失函數和優化器
CONV 表示卷積層,
ResBlock 表示殘差模塊,
Upconv 表示上采樣(也可以用反卷積代替)。
從圖中可以看出,該模型使用了 “multi-scale” 的結構,
在輸入和輸出部分都都采用了高斯金字塔(Gaussian pyramid)的形式(即對原圖像進行不同尺度的下采樣,從而獲得處於不同分辨率的圖像)
模型定義
- default_conv 是模型采用的默認卷積層,
- UpConv 用於上采樣卷積,
- ResidualBlock 是模型使用的殘差模塊,
- SingleScaleNet 是單個尺度網絡,
- MultiScaleNet 將幾個 SingleScaleNet 整合成了最終的多尺度網絡模型
具體作用
- default_conv : 網絡中默認采用的卷積層,定義之后,避免重復代碼
- UpConv : 上卷積,對應上圖中的 Up Conv,將圖像的尺度擴大,輸入到另一個單尺度網絡
- ResidualBlock : 殘差模塊,網絡模型中采用的殘差模塊,之所以采用殘差模塊,是因為網絡“只需要需要模糊圖像與去模糊圖像之間的差異即可”
- SingleScaleNet : 單尺度模型,一個尺度對應一個單尺度模型實例
- MultiScaleNet : 多尺度模型,將多個單尺度模型實例組合即可得到上圖所示的多尺度去模糊網絡
def default_conv(in_channels,out_channels, kernel_size, bias):
return nn.Conv2d(in_channels,
out_channels,
kernel_size,
padding=(kernel_size // 2),
bias=bias)
class UpConv(nn.Module):
def __init__(self):
super(UpConv, self).__init__()
self.body = nn.Sequential(default_conv(3,12,3,True),
nn.PixelShuffle(2),
nn.ReLU(inplace=True))
def forward(self,x):
return self.body(x)
class ResidualBlock(nn.Module):
def __init__(self,n_feats):
super(ResidualBlock,self).__init__()
modules_body = [
default_conv(n_feats, n_feats, 3, bias=True),
nn.ReLU(inplace=True),
default_conv(n_feats,n_feats,3,bias=True)
]
self.body = nn.Sequential(*modules_body)
def forward(self,x):
res= self.body(x)
res += x
return res
class SingleScaleNet(nn.Module):
def __init__(self,n_feats,n_resblocks, is_skip, n_channels=3):
super(SingleScaleNet, self).__init__()
self.is_skip = is_skip
modules_head = [
default_conv(n_channels,n_feats,5,bias=True),
nn.ReLU(inplace=True)
]
modules_body = [ResidualBlock(n_feats) for _ in range(n_resblocks)]
modules_tail = [default_conv(n_feats, 3,5,bias=True)]
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
def forward(self,x):
x= self.head(x)
res= self.body(x)
if self.is_skip:
res += x
res = self.tail(res)
return res
class MultiScaleNet(nn.Module):
def __init__(self,n_feats, n_resblocks ,is_skip):
super(MultiScaleNet,self).__init__()
self.scale3_net = SingleScaleNet(n_feats,
n_resblocks,
is_skip,
n_channels=3)
self.upconv3 = UpConv()
self.scale2_net = SingleScaleNet(n_feats,
n_resblocks,
is_skip,
n_channels=6)
self.upconv2 = UpConv()
self.scale1_net = SingleScaleNet(n_feats,
n_resblocks,
is_skip,
n_channels=6)
def forward(self,mulscale_input):
input_b1, input_b2,input_b3 = mulscale_input
output_l3 = self.scale3_net(input_b3)
output_l3_up = self.upconv3(output_l3)
output_l2 = self.scale2_net(torch.cat((input_b2,output_l3_up),1))
output_l2_up = self.upconv2(output_l2)
output_l1 = self.scale2_net(torch.cat((input_b1,output_l2_up),1))
return output_l1,output_l2,output_l3
模型實例化
if args.multi:
my_model = MultiScaleNet(n_feats=args.n_feats,
n_resblocks = args.n_resblocks,
is_skip= args.skip)
else:
my_model = SingleScaleNet(n_feats=args.n_feats,
n_resblocks=args.n_resblocks,
is_skip = args.skip)
if torch.cuda.is_available():
my_model.cuda()
loss_function = nn.MSELoss().cuda()
else:
loss_function = nn.MSELoss()
optimizer = optim.Adam(my_model.parameters(),lr=args.lr)
print(my_model)
print(loss_function)
MultiScaleNet(
(scale3_net): SingleScaleNet(
(head): Sequential(
(0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU(inplace=True)
)
(body): Sequential(
(0): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(1): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(2): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(tail): Sequential(
(0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
)
(upconv3): UpConv(
(body): Sequential(
(0): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): PixelShuffle(upscale_factor=2)
(2): ReLU(inplace=True)
)
)
(scale2_net): SingleScaleNet(
(head): Sequential(
(0): Conv2d(6, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU(inplace=True)
)
(body): Sequential(
(0): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(1): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(2): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(tail): Sequential(
(0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
)
(upconv2): UpConv(
(body): Sequential(
(0): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): PixelShuffle(upscale_factor=2)
(2): ReLU(inplace=True)
)
)
(scale1_net): SingleScaleNet(
(head): Sequential(
(0): Conv2d(6, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU(inplace=True)
)
(body): Sequential(
(0): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(1): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(2): ResidualBlock(
(body): Sequential(
(0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(tail): Sequential(
(0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
)
)
MSELoss()
損失函數和優化器
- Adam 優化器,初始學習率為 lr,其相對於 SGD,更自動化,實際中需要調整的參數較少,但需要注意的是,其使用內存也比 SGD 要高。
- 損失函數使用最常見的均方損失函數(MSELoss):
其中 \(f^{\prime}(i,j)\) 和 \(f(i,j)\) 分別為模型輸出結果圖和非模糊圖上坐標為 \((i,j)\) 的像素,M,N分別表示圖片的長與寬。 - 具體的,本文所用的多尺度損失函數為:
\(f^{\prime}_k\) 和 \(f_k\) 分別表示第 \(k\) 個尺度上的輸出結果圖和非模糊圖。
模型訓練
- 訓練策略
- 訓練模型
- 訓練過程可視化
訓練策略
- 在模型訓練過程中,隨着訓練的進行,更新網絡參數的步進(學習率)應該越來越小,整體訓練過程應該滿足 “先粗調后細調”,這就是常說的學習率策略。
- 本次訓練采用的學習率優化策略為 lr_scheduler.StepLR,步進為 lr_step_size,學習率每隔 lr_step_size 個 epoch 乘以 lr_gamma
scheduler = lr_scheduler.StepLR(optimizer,args.lr_step_size,args.lr_gamma)
scheduler
<torch.optim.lr_scheduler.StepLR at 0x290378a0198>
訓練模型
在訓練開始之前,要先創建一個 SummaryWriter,用來記錄和可視化訓練過程
writer = SummaryWriter(os.path.join(args.save_dir,"temp/logs/"))
writer
<tensorboardX.writer.SummaryWriter at 0x290378d6e10>
- 在命令行運行 tensorboard --logdir=experiment/logs 來啟動tensorboard。
- 在訓練模型時,每訓練完一個 epoch 將模型的參數保存下來,防止訓練被意外中斷以及方便測試,如果需要不斷更新最新的一次訓練的參數,可以取消最后一行的注釋。
- 訓練過程中,使用 tqdm 的進度條來觀察訓練過程
bar_format = '{desc}{percentage:3.0f}% | [{elapsed}<{remaining},{rate_fmt}]' # 進度條格式
for epoch in range(args.epochs):
total_loss = 0
batch_bar = tqdm(data_loader, bar_format=bar_format) # 利用tqdm動態顯示訓練過程
for batch,images in enumerate(batch_bar):
my_model.train()
curr_batch = epoch * data_loader.__len__() + batch # 當前batch在整個訓練過程中的索引
input_b1 = images['input_b1'].to(device) # 原始輸入圖像
target_s1 = images['target_s1'].to(device) # 目標非模糊圖片
if args.multi:
input_b2 = images['input_b2'].to(device) # level-2 尺度
target_s2 = images['target_s2'].to(device)
input_b3 = images['input_b3'].to(device) # level-3 尺度
target_s3 = images['target_s3'].to(device)
output_l1, output_l2, output_l3 = my_model((input_b1,input_b2,input_b3))
# 損失函數
loss = (loss_function(output_l1,target_s1) + loss_function(output_l2,target_s2) + loss_function(output_l3, target_s3)) /3
else:
output_l1 = my_model(input_b1)
loss = loss_function(output_l1,target_s1)
my_model.zero_grad()
loss.backward() #反向傳播
optimizer.step() # 更新權值
total_loss += loss.item()
print_str = "|".join([
"epoch:%3d/%3d" % (epoch + 1, args.epochs),
"batch:%3d/%3d" % (batch + 1, data_loader.__len__()),
"loss:%.5f" % (loss.item()),
])
batch_bar.set_description(print_str,refresh=True) # 更新進度條
writer.add_scalar('train/batch_loss', loss.item(), curr_batch)
batch_bar.close()
scheduler.step() #調整學習率
loss = total_loss / (batch +1)
writer.add_scalar('train/batch_loss',loss,epoch)
torch.save(my_model.state_dict(),os.path.join(args.save_cp_dir, f'Epoch_{epoch}.pt')) # 保存每個 epoch 的參數
# torch.save(my_model.state_dict(),os.path.join(args.save_cp_dir, f'Epoch_lastest.pt')) # 保存最新的參數
模型評估
- 指標介紹
- 指標實現
指標介紹
為了評估模型的效果如何,我們通過計算
峰值信噪比(Peak Signal-to-Noise Ratio, PSNR),
結構相似性(Structural Similarity, SSIM)和
多尺度的 SSIM(Multi-Scale SSIM,MSSIM)三個指標來對結果進行分析
PSNR
PSNR 的定義如下:
其中,\(M A X_{I}\)表示圖像點顏色的最大數值,如果每個采樣點用 8 位表示,則最大數值為 255,\(MSE\)是兩個圖像之間的均方誤差。
PSNR值越大代表模糊圖像與參考圖像越接近,即去模糊效果越好。
SSIM
SSIM也是衡量兩幅圖片相似性的指標,其定義如下:
SSIM由模型輸出圖像 \(x\) 和參考圖像 \(y\) 之間的亮度對比($ l(\mathbf{x}, \mathbf{y})\()、對比度對比(\)c(\mathbf{x}, \mathbf{y})\()和結構對比(\)s(\mathbf{x}, \mathbf{y}) \()三部分組成,\)\alpha\(,\)\beta$ 和 \(\gamma\)是各自的權重因子,一般都取為 1:
其中,\(C_{1}\),\(C_{2}\)和\(C_{3}\)為常數,是為了避免分母接近於0時造成的不穩定性。\(\mu_{x}\) 和 \(\mu_{y}\) 分別為模型輸出圖像和參考圖像的均值。\(\sigma_{x}\) 和 \(\sigma_{y}\) 分別為模型輸出圖像和參考圖像的標准差。通常取 \(C1=(K1*L)^2\),\(C2=(K2*L)^2\),\(C3=C2/2\),一般地\(K1=0.01\), \(K2=0.03\), \(L=255\)( \(L\)是像素值的動態范圍,一般都取為255)。
輸出圖片和目標圖片的結構相似值越大,則表示相似性越高,圖像去模糊效果越好。
SSIM是一種符合人類直覺的圖像質量評價標准。從名字上我們不難發現,這種指標是在致力於向人類的真實感知看齊,詳細細節可以參考原論文
MSSIM
MSSIM相當於是在多個尺度上來進行SSIM指標的測試,相對於SSIM,其能更好的衡量圖像到觀看者的距離、像素信息密集程度等因素對觀看者給出的主觀評價所產生的影響。
論文中給出的一個例子是,觀看者給一個分辨率為1080p的較為模糊的畫面的評分可能會比分辨率為720p的較為銳利的畫面的評分高。因此在評價圖像質量的時候不考慮尺度因素可能會導致得出片面的結果。
MSSIM提出在不同分辨率(尺度)下多次計算結構相似度后綜合結果得到最終的評價數值。其計算過程框圖如下所示
MSSIM 的詳細細節可以參考原論文
指標實現
# class PSNR(nn.Module):
class PSNR(nn.Module):
def forward(self,img1,img2):
mse = ((img1 - img2) ** 2).mean() # 輸出圖像和參考圖像的 MSE
psnr = 10 * torch.log10(1.0 * 1.0 / (mse + 10 ** (-10)))
return psnr
# SSIM 和 MSSIM 的計算較為復雜,在這里,我們直接調用 pytorch-msssim 的接口來進行計算
ssim = pytorch_msssim.SSIM(data_range=1.0, size_average=True, channel=3)
mssim = pytorch_msssim.MS_SSIM(data_range=1.0, size_average=True, channel=3)
# 實例化
ssim = pytorch_msssim.SSIM(data_range=1.0, size_average=True, channel=3)
mssim = pytorch_msssim.MS_SSIM(data_range=1.0, size_average=True, channel=3)
psnr = PSNR()
模型預測
- 繪圖函數定義
- 模型加載
- 數據加載
- 模型預測與指標分析
- 結果展示與保存
繪圖函數定義
def plot_tensor(tensor):
if tensor.dim() == 4:
tensor = tensor.squeeze(0)
ret = transforms.ToPILImage()(tensor.squeeze(0))
plt.imshow(ret)
return
模型加載
訓練過程中我們保存了多個 checkpoint ,現在對其進行加載和測試。這里我們提供了兩種選擇 checkpoint 的方式,一種是選擇指定 checkpoint,一種是選擇最新的 checkpoint。在這里我們以最新的 checkpoint 為例進行測試
# option-A :測試指定epoch
# best_epoch = 100
# best_cp = f"{args.save_cp_dir}/Epoch_{best_epoch}.pt"
# option-B :測試最終epoch
# best_cp = f"{args.save_cp_dir}/Epoch_lastest.pt"
best_cp = f"{args.save_cp_dir}/Epoch_3.pt"
my_model.to("cuda").load_state_dict(torch.load(best_cp))
my_model = my_model.eval()
數據加載
# 由於此模型采用的是多尺度訓練,因此對於單張輸入圖像,需要對其進行處理,定義加載圖像的函數 load_images 為
def load_images(blur_img_path,multi):
target_s1 = None
sharp_img_path = blur_img_path.replace("blur","sharp")
if os.path.exists(sharp_img_path):
img_target = Image.open(sharp_img_path).convert('RGB')
target_s1 = transforms.ToTensor()(img_target).unsqueeze(0)
img_input = Image.open(blur_img_path).convert('RGB') # 轉換為image類型 方便進行resize
input_b1 = transforms.ToTensor()(img_input)
if multi:
H = input_b1.size()[1]
W = input_b1.size()[2]
input_b1 = transforms.ToPILImage()(input_b1)
input_b2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(input_b1)).unsqueeze(0)
input_b3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(input_b1)).unsqueeze(0)
input_b1 = transforms.ToTensor()(input_b1).unsqueeze(0)
return {'input_b1':input_b1, 'input_b2':input_b2, 'input_b3':input_b3, 'target_s1':target_s1}
else:
return {'input_b1':unsqueeze(0), 'target_s1':target_s1}
模型預測與指標分析
模型預測
#目錄一
# idx = 1
# blur_img_path = f"datasets/pictures/blur/test{idx}.png"
# 目錄二
idx='000001'
blur_img_path =f"datasets/test/GOPR0384_11_00/blur/{idx}.png"
item = load_images(blur_img_path,args.multi)
input_b1 = item['input_b1'].to(device)
input_b2 = item['input_b2'].to(device)
input_b3 = item['input_b3'].to(device)
target_s1 = item['target_s1'].to(device)
output_l1,_,_ = my_model((input_b1,input_b2,input_b3))
指標分析
原始模糊圖片與不模糊圖片之間的指標計算
blur_psnr = psnr(input_b1,target_s1)
blur_ssim = ssim(input_b1,target_s1)
blur_mssim = mssim(input_b1,target_s1)
print(f"原始模糊圖片:PSNR={blur_psnr.float()}, SSIM={blur_ssim.float()}, MSSIM={blur_mssim.float()}")
原始模糊圖片:PSNR=24.050003051757812, SSIM=0.716961145401001, MSSIM=0.840461015701294
去模糊圖片與不模糊的圖片之間的指標計算
output_psnr = psnr(output_l1,target_s1)
output_ssim = ssim(output_l1,target_s1)
output_mssim = mssim(output_l1,target_s1)
print(f"網絡輸出圖片:PSNR={output_psnr.float()}, SSIM={output_ssim.float()}, MSSIM={output_mssim.float()}")
網絡輸出圖片:PSNR=24.012224197387695, SSIM=0.7089502811431885, MSSIM=0.8413411974906921
結果展示
plt.figure(figsize=(6,10))
plt.subplot(311)
plot_tensor(input_b1)
plt.subplot(312)
plot_tensor(output_l1)
plt.subplot(313)
plot_tensor(target_s1)
# 將結果保存
save_name = blur_img_path.split("/")[-1]
save_path = os.path.join(args.save_dir,save_name)
save_img = transforms.ToPILImage()(output_l1.squeeze(0))
save_img.save(save_path)