ESPCN Single-Frame Super-Resolution Reconstruction

ASC 2019, Problem 3:

(A pile of background introduction omitted here.) In this competition, the participants should design an algorithm using SOTA strategies like deep learning to do 4x SR upscaling for images which were down-sampled with a bicubic kernel. For instance, the resolution of a 400x600 image after 4x upscaling is 1600x2400. The evaluation will be done in a perceptual-quality-aware manner. The perceptual index (PI) defined in PIRM 2018 [4] will be used to measure the quality of the reconstructed high-resolution images; lower PI means higher quality. Ma and NIQE are two no-reference image quality measures [5-6].
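For reference (if I recall the PIRM 2018 definition correctly), the perceptual index combines the two no-reference measures as

    PI = ((10 - Ma) + NIQE) / 2

so a higher Ma score and a lower NIQE score both lower the PI, i.e. indicate better perceptual quality.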

PS: adapted from a software engineering practicum assignment.



Background:

It was a gloomy stretch of time. Winter break was half over, I had just finished my lab work, and, with zero machine learning experience, I recklessly claimed this task. My English being poor, I didn't read the problem statement carefully: after skimming The Master Algorithm, I picked up Machine Learning in Action, opened a course I'd bought online, and dove headfirst into the TensorFlow rabbit hole. Later I discovered by chance that PyTorch suited my laptop's hardware and environment far better, so I switched to teaching myself PyTorch. I recommend the '莫煩python' (Morvan Python) course here: simple, clear, and quick to get through. After a quick pass over the course I finally read the problem statement properly. 'Only PyTorch may be used; no other machine learning libraries.' Scary in hindsight — luckily I hadn't picked the wrong library.

Implementation:

The first step was a mountain of searching through papers, blog posts, and source code. After comparing and trying several architectures, I settled on ESPCN, the efficient sub-pixel convolutional neural network: a simple design that looked effective. At first I set the initial learning rate to 0.1, and training went terribly — the output was either all black or all white. My first thought was 'dead neurons', so, second-guessing the paper, I added normalization. The results were actually decent: a few epochs of training produced 4x-resolution images with a respectable signal-to-noise ratio. And so, with this counterfeit ESPCN, I trained several models on my gaming laptop's CUDA GPU: the CPU blared high-temperature alarms the whole time, and in the depths of winter I ran it to completion, in tears, propped up with a desk fan and a laptop cooling pad.

Code Analysis:

Hunting for source code was painful, and analyzing while hunting was even more of a grind. In hindsight the code isn't all that complicated. Given our actual needs, we only analyze the single-frame super-resolution part.

Source code: GitHub

Analysis approach

  1. Scan the directory tree and read the README.
  2. Conclusion: the project has two main modules, image super-resolution and video super-resolution; whether they depend on each other is unknown.
  3. Plan: trace from test_image.py and train.py to understand the network model, the data loading, and the loss function; skim the training procedure; then try to reimplement it without third-party libraries (per the competition rules).

Detailed analysis

  • test_image.py
import argparse
import os
from os import listdir

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor
from tqdm import tqdm

from data_utils import is_image_file
from model import Net

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test Super Resolution')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--model_name', default='epoch_3_100.pt', type=str, help='super resolution model name')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    MODEL_NAME = opt.model_name

    path = 'data/test/SRF_' + str(UPSCALE_FACTOR) + '/data/'
    images_name = [x for x in listdir(path) if is_image_file(x)]
    model = Net(upscale_factor=UPSCALE_FACTOR)
    if torch.cuda.is_available():
        model = model.cuda()
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME))

    out_path = 'results/SRF_' + str(UPSCALE_FACTOR) + '/'
    if not os.path.exists(out_path):
        os.makedirs(out_path)
    for image_name in tqdm(images_name, desc='convert LR images to HR images'):

        img = Image.open(path + image_name).convert('YCbCr')
        y, cb, cr = img.split()
        image = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        if torch.cuda.is_available():
            image = image.cuda()

        out = model(image)
        out = out.cpu()
        out_img_y = out.data[0].numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')
        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
        out_img.save(out_path + image_name)

Everything above the main loop is mostly module imports and argument parsing; skip it for now.

From `img = Image.open(path + image_name).convert('YCbCr')` and the `split()` that follows, we can see that the module converts each image to YCbCr and super-resolves only the Y (luma) channel — effectively a grayscale pass. I didn't know the exact reason at the time (the usual rationale: human vision is far more sensitive to luminance than to chrominance, so the network handles Y while Cb and Cr are upscaled with plain bicubic interpolation), so I chose to process all three channels directly instead, while keeping the overall idea.
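For reference, a minimal sketch of the three-channel preprocessing I used instead (my own adaptation; the helper name is mine, not from the repo):

from PIL import Image
from torchvision.transforms import ToTensor

def load_rgb_tensor(image_path):
    # Keep all three RGB channels instead of extracting Y from YCbCr.
    img = Image.open(image_path).convert('RGB')
    # ToTensor() scales pixels to [0, 1] and returns (C, H, W);
    # unsqueeze(0) adds the batch dimension the network expects.
    return ToTensor()(img).unsqueeze(0)  # shape: (1, 3, H, W)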

Tracing the `Net` import back to model.py gives the network model.

The data-loading code here is a bit messy, and I find the overall design hard to stomach — let's dig it out of train.py instead. First, the network.

  • model.py
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = F.tanh(self.conv1(x))
        x = F.tanh(self.conv2(x))
        x = F.tanh(self.conv3(x))
        x = F.sigmoid(self.pixel_shuffle(self.conv4(x)))
        return x

This is sub-pixel convolution in its plainest form — easy to understand, and easy to reimplement for three channels.
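To make the sub-pixel step concrete, here is a small shape check (my own illustration, not from the repo). nn.PixelShuffle(r) rearranges a (N, C*r^2, H, W) tensor into (N, C, H*r, W*r), which is how conv4's upscale_factor ** 2 output channels turn into extra spatial resolution:

import torch
import torch.nn as nn

r = 3  # upscale factor
shuffle = nn.PixelShuffle(r)
x = torch.randn(1, 1 * r ** 2, 100, 100)   # shaped like conv4's output
y = shuffle(x)
print(y.shape)  # torch.Size([1, 1, 300, 300]) -- (N, C, H*r, W*r)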

One unusual point: there is no normalization between the hidden layers, so the learning rate must not be too large. At first I trained with a large learning rate and got unrecognizable images, so I quietly added normalization, going against the original paper. Looking back now, though, adding normalization between some of the convolutional layers might genuinely give better training results.
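For the record, a minimal sketch of what my normalized variant looked like (my own deviation from the ESPCN paper, reconstructed from memory — BatchNorm after each hidden convolution):

import torch
import torch.nn as nn

class NetBN(nn.Module):
    def __init__(self, upscale_factor):
        super(NetBN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = torch.tanh(self.bn1(self.conv1(x)))
        x = torch.tanh(self.bn2(self.conv2(x)))
        x = torch.tanh(self.bn3(self.conv3(x)))
        return torch.sigmoid(self.pixel_shuffle(self.conv4(x)))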

Sigmoid is something I had mostly seen squashing a score into a probability for binary classification; using it here simply to constrain the output pixels to [0, 1] is an unusual trick — noted.

  • train.py
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger
from tqdm import tqdm

from data_utils import DatasetFromFolder
from model import Net
from psnrmeter import PSNRMeter


def processor(sample):
    data, target, training = sample
    data = Variable(data)
    target = Variable(target)
    if torch.cuda.is_available():
        data = data.cuda()
        target = target.cuda()

    output = model(data)
    loss = criterion(output, target)

    return loss, output


def on_sample(state):
    state['sample'].append(state['train'])


def reset_meters():
    meter_psnr.reset()
    meter_loss.reset()


def on_forward(state):
    meter_psnr.add(state['output'].data, state['sample'][1])
    meter_loss.add(state['loss'].data[0])


def on_start_epoch(state):
    reset_meters()
    scheduler.step()
    state['iterator'] = tqdm(state['iterator'])


def on_end_epoch(state):
    print('[Epoch %d] Train Loss: %.4f (PSNR: %.2f db)' % (
        state['epoch'], meter_loss.value()[0], meter_psnr.value()))

    train_loss_logger.log(state['epoch'], meter_loss.value()[0])
    train_psnr_logger.log(state['epoch'], meter_psnr.value())

    reset_meters()

    engine.test(processor, val_loader)
    val_loss_logger.log(state['epoch'], meter_loss.value()[0])
    val_psnr_logger.log(state['epoch'], meter_psnr.value())

    print('[Epoch %d] Val Loss: %.4f (PSNR: %.2f db)' % (
        state['epoch'], meter_loss.value()[0], meter_psnr.value()))

    torch.save(model.state_dict(), 'epochs/epoch_%d_%d.pt' % (UPSCALE_FACTOR, state['epoch']))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Train Super Resolution')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)

    model = Net(upscale_factor=UPSCALE_FACTOR)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    meter_psnr = PSNRMeter()

    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
    train_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Train PSNR'})
    val_loss_logger = VisdomPlotLogger('line', opts={'title': 'Val Loss'})
    val_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Val PSNR'})

    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch

    engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer)

Most of what train.py does we already learned from the test script; the main things to study here are the loss function and how the training data is prepared.

Data loading goes through DatasetFromFolder, and the loss function is mean squared error (the training script leans on libraries the competition forbids, so I'll just write the MSE myself — see the sketch below).
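A minimal hand-rolled MSE in plain PyTorch tensor ops (my own replacement for nn.MSELoss; it assumes output and target are tensors of identical shape):

def mse_loss(output, target):
    # Mean of squared per-element differences; autograd provides the backward pass.
    return ((output - target) ** 2).mean()

# usage: loss = mse_loss(model(data), target); loss.backward()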

  • DatasetFromFolder
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data import Dataset

# is_image_file is a small filename-extension helper defined earlier in data_utils.py

class DatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/data'
        self.target_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/target'
        self.image_filenames = [join(self.image_dir, x) for x in listdir(self.image_dir) if is_image_file(x)]
        self.target_filenames = [join(self.target_dir, x) for x in listdir(self.target_dir) if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, _, _ = Image.open(self.image_filenames[index]).convert('YCbCr').split()
        target, _, _ = Image.open(self.target_filenames[index]).convert('YCbCr').split()
        if self.input_transform:
            image = self.input_transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.image_filenames)

Subclass the official Dataset class, implement __getitem__ and __len__, and it can be fed straight to a DataLoader for batching!!! Noted. But once again I skipped the grayscale (Y-channel) step and trained directly on three channels, as sketched below.
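A minimal sketch of the three-channel dataset I trained on instead (my own adaptation; the class name and layout are mine, not the repo's):

from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

class RGBDatasetFromFolder(Dataset):
    # Same LR/HR folder pairing as the original, but images stay in RGB,
    # so the network must take 3 input channels and emit 3 * r ** 2.
    def __init__(self, image_dir, target_dir):
        # sorted() keeps LR and HR filenames aligned pairwise.
        self.image_filenames = sorted(join(image_dir, x) for x in listdir(image_dir))
        self.target_filenames = sorted(join(target_dir, x) for x in listdir(target_dir))

    def __getitem__(self, index):
        image = ToTensor()(Image.open(self.image_filenames[index]).convert('RGB'))
        target = ToTensor()(Image.open(self.target_filenames[index]).convert('RGB'))
        return image, target

    def __len__(self):
        return len(self.image_filenames)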

Optimization:

As training went on and I understood machine learning a little better, I realized I had indeed been 'too young, too simple'. At some point I had the idea of removing the normalization again. With some trepidation, I did exactly that, and set the initial learning rate to 1e-4. The result felt like a miracle: training was not only much faster, it also produced better images than before. My guess at the explanation: with the earlier, overly large learning rate, the parameter updates were so big that the generated pixel values saturated at 0 or 1 (in all three channels); small subsequent updates left them stuck at 0/1, the mean squared error barely changed, the parameters kept drifting inside a wrong region, and the network could not learn anything useful — in effect, 'dead neurons'. So shrinking the initial learning rate is a good choice; without stacking normalization layers, it alone speeds up training enormously.
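A tiny demonstration of the saturation I suspect was the culprit (my own illustration): once the pre-sigmoid activation is pushed far from zero, the sigmoid's gradient is nearly zero, so the upstream weights receive almost no learning signal:

import torch

for pre_activation in [0.0, 3.0, 10.0]:
    x = torch.tensor(pre_activation, requires_grad=True)
    torch.sigmoid(x).backward()
    # grad = sigmoid(x) * (1 - sigmoid(x)): 0.25 at x=0, ~4.5e-5 at x=10
    print(f'x = {pre_activation:5.1f}  d(sigmoid)/dx = {x.grad.item():.6f}')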

Results:

Enlargement results on images outside the training set are shown below: the model's output on the left, the original image enlarged on the right.

Source code: espcn

PS: the code has been uploaded; I may add a GAN later.

