Learning from a PyTorch UNet Example


Reference: https://github.com/milesial/Pytorch-UNet

It implements binary semantic segmentation of car images, including dense CRF post-processing.

It requires Python 3; my environment is Python 3.6.

 

1. Usage

1> Prediction

1) View all available options:

python predict.py -h

This returns:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
                  [--output INPUT [INPUT ...]] [--cpu] [--viz] [--no-save]
                  [--no-crf] [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which is stored the model (default
                        : 'MODEL.pth')  # the trained model file to use; defaults to MODEL.pth
  --input INPUT [INPUT ...], -i INPUT [INPUT ...] # the image files to predict on; required
                        filenames of input images
  --output INPUT [INPUT ...], -o INPUT [INPUT ...] # names for the generated output images
                        filenames of ouput images
  --cpu, -c             Do not use the cuda version of the net # run on the CPU; off by default, i.e. the GPU is used
  --viz, -v             Visualize the images as they are processed # show each image as it is processed; off by default
  --no-save, -n         Do not save the output masks # do not save the predicted masks; combined with --viz, results are visualized but not stored; off by default, i.e. results are saved
  --no-crf, -r          Do not use dense CRF postprocessing # skip the dense CRF post-processing; off by default, i.e. the CRF is applied
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel # minimum probability for a mask pixel to count as white; default 0.5
                        white
  --scale SCALE, -s SCALE
                        Scale factor for the input images # scale factor for the input images; default 0.5

 

2) Predict a single image image.jpg and save the result to output.jpg:

python predict.py -i image.jpg -o output.jpg

Try it out:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py --cpu --viz -i image.jpg -o output.jpg
Loading model MODEL.pth
Using CPU version of the net, this may be very slow
Model loaded !

Predicting image image.jpg ...
/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Visualizing results for image image.jpg, close to continue ...

The returned visualization:

Closing the visualization window lets the run finish:

Mask saved to output.jpg
(deeplearning) userdeMBP:Pytorch-UNet-master user$ 

and a file named output.jpg is created in the current folder:

 

 

3) Predict multiple images and visualize them without saving the results:

python predict.py -i image1.jpg image2.jpg --viz --no-save

Test:

The visualization for image1.jpg appears first:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -i image1.jpg image2.jpg --viz --no-save --cpu
Loading model MODEL.pth
Using CPU version of the net, this may be very slow
Model loaded !

Predicting image image1.jpg ...
/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/nn/functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Visualizing results for image image1.jpg, close to continue ...

The figure:

After closing it, the visualization for image2.jpg is generated next:

Predicting image image2.jpg ...
Visualizing results for image image2.jpg, close to continue ...

The returned figure:

Closing this visualization ends the run, and no generated images are saved locally.

 

4) If your machine is CPU-only, specify the --cpu option.

5) You can specify which trained model file to use with --model MODEL.pth.

 

6) If you add the --no-crf option to the command above:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -i image1.jpg image2.jpg --viz --no-save --cpu --no-crf

The returned result is:

And:

Clearly the CRF post-processing removes some implausible predictions, making the result more accurate.

 

2> Training

python train.py -h

First the pydensecrf package, which implements dense conditional random fields (CRFs), needs to be installed:

pip install pydensecrf

But this fails with:

pydensecrf/densecrf/include/Eigen/Core:22:10: fatal error: 'complex' file not found
#include <complex>
         ^~~~~~~~~
1 warning and 1 error generated.
error: command 'gcc' failed with exit status 1
----------------------------------------
Failed building wheel for pydensecrf
Running setup.py clean for pydensecrf
Failed to build pydensecrf

The workaround, following https://github.com/lucasb-eyer/pydensecrf:

First install Cython; version 0.22 or later is required:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ pip install -U cython
Installing collected packages: cython
Successfully installed cython-0.29.7

Then install the latest version from git:

pip install git+https://github.com/lucasb-eyer/pydensecrf.git

But this still did not succeed.

 

Later I found another way: installing with conda worked:

userdeMacBook-Pro:~ user$ conda install -n deeplearning -c conda-forge pydensecrf

-c tells conda to pull the package from the conda-forge channel.

conda-forge is an additional channel of installable packages, used here instead of the defaults channel,

because a plain conda install -n deeplearning pydensecrf cannot find the package.
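
A quick way to confirm the install worked (my own check, not from the original post) is to import the module:

python -c "import pydensecrf.densecrf as dcrf; print(dcrf.DenseCRF2D)"

If this prints the class instead of raising ImportError, pydensecrf is ready.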

 

Now running python train.py -h shows the supported options:

(deeplearning) userdeMBP:Pytorch-UNet-master user$ python train.py -h
Usage: train.py [options]

Options:
  -h, --help            show this help message and exit
  -e EPOCHS, --epochs=EPOCHS
                        number of epochs # number of training epochs
  -b BATCHSIZE, --batch-size=BATCHSIZE
                        batch size # how many images per batch
  -l LR, --learning-rate=LR
                        learning rate # the learning rate to use
  -g, --gpu             use cuda # train on the GPU
  -c LOAD, --load=LOAD  load file model # load a previously trained model and continue training from it
  -s SCALE, --scale=SCALE
                        downscaling factor of the images # factor by which the images are downscaled
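
For example (illustrative values of my own, not from the original post), a GPU training run with 20 epochs, batch size 4, and learning rate 0.01 would be:

python train.py -g -e 20 -b 4 -l 0.01 -s 0.5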

 

3> Code analysis

1》 unet: defining the network

unet/unet_parts.py

# sub-parts of the U-Net model

import torch
import torch.nn as nn
import torch.nn.functional as F

#實現左邊的橫向卷積
class double_conv(nn.Module): 
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            #以第一層為例進行講解
            #輸入通道數in_ch,輸出通道數out_ch,卷積核設為kernal_size 3*3,padding為1,stride為1,dilation=1
            #所以圖中H*W能從572*572 變為 570*570,計算為570 = ((572 + 2*padding - dilation*(kernal_size-1) -1) / stride ) +1
            nn.Conv2d(in_ch, out_ch, 3, padding=1), 
            nn.BatchNorm2d(out_ch), #進行批標准化,在訓練時,該層計算每次輸入的均值與方差,並進行移動平均
            nn.ReLU(inplace=True), #激活函數
            nn.Conv2d(out_ch, out_ch, 3, padding=1), #再進行一次卷積,從570*570變為 568*568
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

# Implements the first double convolution at the top of the contracting path
class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch) # e.g. in_ch=3 input channels, out_ch=64 output channels

    def forward(self, x):
        x = self.conv(x)
        return x

# Implements a downsampling step of the contracting path: max-pooling followed by the next level's double convolution
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x

# Implements an upsampling step of the expanding path, followed by that level's double convolution
class up(nn.Module): 
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear: # bilinear interpolation, the default; output size is floor(H * scale_factor), so 28*28 becomes 56*56
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else: # otherwise upsample with a transposed convolution; output size is (H-1)*stride - 2*padding + kernel_size + output_padding
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2): # x2 is the skip-connection feature map from the contracting path
        # the first upsampling step yields 56*56, but we are not done yet
        x1 = self.up(x1)
        
        # input is CHW: dim [0] is the batch size, [1] the channel count (the author changed this part, so it differs from the upstream source)
        diffY = x1.size()[2] - x2.size()[2] # difference in H between x1 and x2, e.g. 56 - 64 = -8
        diffX = x1.size()[3] - x2.size()[3] # difference in W between x1 and x2, e.g. 56 - 64 = -8

        # Using the first upsampling step as an example: when the upsampled result and the skip feature differ in size,
        # pad x2 so that it matches x1. Padding with (-4, -4, -4, -4) crops 4 pixels
        # from left/right/top/bottom, turning 64*64 into 56*56.
        x2 = F.pad(x2, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        
        # Concatenate the upsampled x1 with the skip feature x2 along dim=1, the channel dimension, e.g. 512 channels become 1024
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

# Implements the final 1x1 convolution at the top of the expanding path
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x
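
A quick shape check (my own sketch, not part of the repo) confirms that with padding=1 these double convolutions preserve the spatial size, unlike the unpadded convolutions of the original U-Net paper:

import torch

block = double_conv(3, 64)            # as defined in unet/unet_parts.py above
x = torch.randn(1, 3, 572, 572)       # (batch, channels, H, W)
print(block(x).shape)                 # torch.Size([1, 64, 572, 572]): H and W unchanged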

 

unet/unet_model.py

# full assembly of the sub-parts to form the complete net

import torch.nn.functional as F

from .unet_parts import *

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes): # n_channels is the number of image channels: 1 for grayscale, 3 for color
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64) # assuming n_channels=3 for color input, the output channel count is 64
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return F.sigmoid(x) # binary classification; F.sigmoid is deprecated in favor of torch.sigmoid, hence the warning seen in the prediction logs above
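
A forward-pass sanity check (again my own sketch): the output keeps the input's spatial size, has n_classes channels, and lies in (0, 1) thanks to the sigmoid:

import torch

net = UNet(n_channels=3, n_classes=1)
x = torch.randn(1, 3, 64, 64)           # dummy color image
y = net(x)
print(y.shape)                          # torch.Size([1, 1, 64, 64])
print(float(y.min()), float(y.max()))   # both strictly between 0 and 1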

 

 

2》 utils

utils/crf.py implements the dense CRF:

For details, see the pydensecrf documentation.

#coding:utf-8
import numpy as np
import pydensecrf.densecrf as dcrf

def dense_crf(img, output_probs): # img is the input image; output_probs is the probability map predicted by the network
    h = output_probs.shape[0] # height
    w = output_probs.shape[1] # width

    output_probs = np.expand_dims(output_probs, 0)
    output_probs = np.append(1 - output_probs, output_probs, axis=0)

    d = dcrf.DenseCRF2D(w, h, 2) # NLABELS=2: two classes, car and not-car
    U = -np.log(output_probs) # unary potentials: negative log-probabilities
    U = U.reshape((2, -1)) # one row per label
    U = np.ascontiguousarray(U) # pydensecrf needs contiguous arrays
    img = np.ascontiguousarray(img)

    d.setUnaryEnergy(U) # set the unary potentials

    d.addPairwiseGaussian(sxy=20, compat=3) # Gaussian pairwise potential: location-only smoothness
    d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10) # bilateral pairwise potential: location plus color

    Q = d.inference(5) # run 5 iterations of inference
    Q = np.argmax(np.array(Q), axis=0).reshape((h, w)) # most probable label per pixel

    return Q
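
A hypothetical usage sketch (the shapes and values here are mine, not from the post): dense_crf takes the uint8 RGB image and a per-pixel foreground probability map, and returns a refined 0/1 label map of the same size:

import numpy as np

img = np.zeros((240, 320, 3), dtype=np.uint8)                              # RGB image, (H, W, 3)
probs = np.clip(np.random.rand(240, 320), 0.05, 0.95).astype(np.float32)   # foreground probabilities, clipped so -log stays finite
refined = dense_crf(img, probs)                                            # (240, 320) array of labels 0/1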

 

utils/utils.py

import random
import numpy as np

# Split the image into a left square and a right square
def get_square(img, pos): 
    """Extract a left or a right square from ndarray shape : (H, W, C))"""
    h = img.shape[0]
    if pos == 0:
        return img[:, :h]
    else:
        return img[:, -h:]

def split_img_into_squares(img):
    return get_square(img, 0), get_square(img, 1)

# Transpose the image from (H, W, C) to (C, H, W)
def hwc_to_chw(img):
    return np.transpose(img, axes=[2, 0, 1])

def resize_and_crop(pilimg, scale=0.5, final_height=None):
    w = pilimg.size[0] # image width
    h = pilimg.size[1] # image height
    # with the default scale of 0.5, both height and width are halved
    newW = int(w * scale) 
    newH = int(h * scale)

    # if no target final height is specified
    if not final_height:
        diff = 0
    else:
        diff = newH - final_height
    # resize the image
    img = pilimg.resize((newW, newH))
    # crop((left, upper, right, lower)) extracts a rectangular region; the origin (0, 0) is the top-left corner.
    # Without final_height this keeps the whole image;
    # with final_height it trims diff // 2 from top and bottom, leaving a height of final_height.
    img = img.crop((0, diff // 2, newW, newH - diff // 2))
    return np.array(img, dtype=np.float32)

def batch(iterable, batch_size):
    """批量處理列表"""
    b = []
    for i, t in enumerate(iterable):
        b.append(t)
        if (i + 1) % batch_size == 0:
            yield b
            b = []

    if len(b) > 0:
        yield b

# Split the data into a training set and a validation set
def split_train_val(dataset, val_percent=0.05):
    dataset = list(dataset)
    length = len(dataset) # dataset size
    n = int(length * val_percent) # number of validation samples
    random.shuffle(dataset) # shuffle the data
    return {'train': dataset[:-n], 'val': dataset[-n:]} # note: if n rounds down to 0, dataset[:-0] is empty, so the dataset must be large enough

# Normalize pixel values from [0, 255] to [0, 1]
def normalize(x):
    return x / 255

# Merge the two half-masks back into one full-width mask
def merge_masks(img1, img2, full_w):
    h = img1.shape[0]

    new = np.zeros((h, full_w), np.float32)
    new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
    new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]

    return new


# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
    pixels = mask_image.flatten()
    # We avoid issues with '1' at the start or end (at the corners of
    # the original image) by setting those pixels to '0' explicitly.
    # We do not expect these to be non-zero for an accurate mask,
    # so this should not harm the score.
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs
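
A short demonstration (my own sketch) of the two helpers above:

import numpy as np

print(list(batch(range(7), 3)))   # [[0, 1, 2], [3, 4, 5], [6]]: a final partial batch is still yielded

mask = np.array([[0, 1, 1, 0],
                 [0, 1, 1, 0]], dtype=np.uint8)
print(rle_encode(mask))           # [2 2 6 2]: runs start at pixels 2 and 6 (1-indexed), each of length 2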

 

utils/data_vis.py visualizes the results:

import matplotlib.pyplot as plt

def plot_img_and_mask(img, mask):
    fig = plt.figure()
    a = fig.add_subplot(1, 2, 1) # left panel: the input image
    a.set_title('Input image')
    plt.imshow(img)

    b = fig.add_subplot(1, 2, 2) # right panel: the predicted mask
    b.set_title('Output mask')
    plt.imshow(mask)
    plt.show()

 

utils/load.py

#
# load.py : utils on generators / lists of ids to transform from strings to
#           cropped images and masks

import os

import numpy as np
from PIL import Image

from .utils import resize_and_crop, get_square, normalize, hwc_to_chw


def get_ids(dir):
    """Return the list of ids in the directory"""
    return (f[:-4] for f in os.listdir(dir)) # strip the 4-character extension ('.jpg'), so the filename stem serves as the id


def split_ids(ids, n=2):
    """Split each id into n, creating n tuples (id, k) per id"""
    # equivalent to: for id in ids:
    #                    for i in range(n):
    #                        (id, i)
    # yields the tuple list [(id1,0),(id1,1),(id2,0),(id2,1),...,(idn,0),(idn,1)]
    # The trailing 0/1 later becomes the pos argument of get_square in utils.py:
    # pos=0 takes the left square, pos=1 the right square
    return ((id, i)  for id in ids for i in range(n))


def to_cropped_imgs(ids, dir, suffix, scale):
    """Yield the correctly cropped image for each (id, pos) tuple"""
    for id, pos in ids:
        im = resize_and_crop(Image.open(dir + id + suffix), scale=scale) # resize the image by the given scale factor
        yield get_square(im, pos) # then take the left or right square depending on pos

def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
    """Return all pairs (img, mask)"""

    imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)

    # need to transform from HWC to CHW
    imgs_switched = map(hwc_to_chw, imgs) # transpose from (H, W, C) to (C, H, W)
    imgs_normalized = map(normalize, imgs_switched) # normalize pixel values from [0, 255] to [0, 1]

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale) # apply the same processing to the ground-truth masks

    return zip(imgs_normalized, masks) # and zip the two together


def get_full_img_and_mask(id, dir_img, dir_mask):
    im = Image.open(dir_img + id + '.jpg')
    mask = Image.open(dir_mask + id + '_mask.gif')
    return np.array(im), np.array(mask)
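
To make the generator pipeline concrete, a tiny sketch (the filenames are mine) of what get_ids and split_ids produce for a folder containing a.jpg and b.jpg:

ids = ['a', 'b']              # what get_ids would yield: the filename stems
print(list(split_ids(ids)))   # [('a', 0), ('a', 1), ('b', 0), ('b', 1)]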

 

3》 Prediction

predict.py uses the trained U-Net to predict masks for the input images and post-processes them with a dense CRF:

#coding:utf-8
import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F

from PIL import Image

from unet import UNet
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
from utils import plot_img_and_mask

from torchvision import transforms

def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_dense_crf=True,
                use_gpu=False):

    net.eval() # put the already-trained network into evaluation mode
    img_height = full_img.size[1] # image height
    img_width = full_img.size[0] # image width

    img = resize_and_crop(full_img, scale=scale_factor) # defined in utils/utils.py: resize, crop, and convert the image to an np.array
    img = normalize(img) # normalize pixel values from [0, 255] to [0, 1]

    left_square, right_square = split_img_into_squares(img) # split the image into left and right squares, predicted separately

    left_square = hwc_to_chw(left_square) # transpose from (H, W, C) to (C, H, W) for the network
    right_square = hwc_to_chw(right_square)

    X_left = torch.from_numpy(left_square).unsqueeze(0) # turn (C, H, W) into (1, C, H, W); the network expects a leading batch dimension
    X_right = torch.from_numpy(right_square).unsqueeze(0)
    
    if use_gpu:
        X_left = X_left.cuda()
        X_right = X_right.cuda()

    with torch.no_grad(): # no gradients needed at inference time
        output_left = net(X_left)
        output_right = net(X_right)

        left_probs = output_left.squeeze(0)
        right_probs = output_right.squeeze(0)

        tf = transforms.Compose(
            [
                transforms.ToPILImage(), # back to a PIL image
                transforms.Resize(img_height), # restore the original size
                transforms.ToTensor() # and back to a tensor
            ]
        )
        
        left_probs = tf(left_probs.cpu())
        right_probs = tf(right_probs.cpu())

        left_mask_np = left_probs.squeeze().cpu().numpy()
        right_mask_np = right_probs.squeeze().cpu().numpy()

    full_mask = merge_masks(left_mask_np, right_mask_np, img_width) # merge the two half-masks back into one full mask

    # depending on the flag, post-process the result with the dense CRF
    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold



def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', default='MODEL.pth', # the trained model file to use; defaults to MODEL.pth
                        metavar='FILE',
                        help="Specify the file in which is stored the model"
                             " (default : 'MODEL.pth')")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',  # the image files to predict on; required
                        help='filenames of input images', required=True)

    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', # names for the generated output images
                        help='filenames of ouput images')
    parser.add_argument('--cpu', '-c', action='store_true', # run on the CPU
                        help="Do not use the cuda version of the net",
                        default=False)
    parser.add_argument('--viz', '-v', action='store_true', 
                        help="Visualize the images as they are processed", # show each image as it is processed
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_true', # do not save the predicted masks; combined with --viz, results are visualized but not stored
                        help="Do not save the output masks",
                        default=False)
    parser.add_argument('--no-crf', '-r', action='store_true', # skip the dense CRF post-processing
                        help="Do not use dense CRF postprocessing",
                        default=False)
    parser.add_argument('--mask-threshold', '-t', type=float, 
                        help="Minimum probability value to consider a mask pixel white", # minimum probability for a pixel to count as white
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images", # scale factor for the input images
                        default=0.5)

    return parser.parse_args()

def get_output_filenames(args): # derive the output filenames from the parsed args
    in_files = args.input 
    out_files = []

    if not args.output: # if no output names were given, append '_OUT' to each input filename to form the output name
        for f in in_files:
            pathsplit = os.path.splitext(f) # split into stem and extension: pathsplit[0] is the stem, pathsplit[1] the extension
            out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1])) # the output image filename
    elif len(in_files) != len(args.output): # if output names were given, there must be as many as inputs (two input images need two output names), otherwise error out
        print("Error : Input files and output files are not of the same length")
        raise SystemExit()
    else:
        out_files = args.output

    return out_files

def mask_to_image(mask):
    return Image.fromarray((mask * 255).astype(np.uint8)) # convert the array back to an Image

if __name__ == "__main__":
    args = get_args() #得到輸入的選項設置的值
    in_files = args.input #得到輸入的圖像文件
    out_files = get_output_filenames(args) #從輸入的選項args值中得到輸出文件名

    net = UNet(n_channels=3, n_classes=1) #定義使用的model為UNet,調用在UNet文件夾下定義的unet_model.py,定義圖像的通道為3,即彩色圖像,判斷類型設為1種

    print("Loading model {}".format(args.model)) #指定使用的訓練好的model

    if not args.cpu: #指明使用GPU
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else: #否則使用CPU
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files): #對圖片進行預測
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]: #(W, H, C)
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf= not args.no_crf,
                           use_gpu=not args.cpu)

        if args.viz: #可視化輸入的圖片和生成的預測圖片
            print("Visualizing results for image {}, close to continue ...".format(fn))
            plot_img_and_mask(img, mask)

        if not args.no_save:#設置為False,則保存
            out_fn = out_files[i]
            result = mask_to_image(mask) #從數組array轉成Image
            result.save(out_files[i]) #然后保存

            print("Mask saved to {}".format(out_files[i]))

 

4》 Training

train.py:

import sys
import os
from optparse import OptionParser
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import optim

from eval import eval_net
from unet import UNet
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch

def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.05,
              save_cp=True,
              gpu=False,
              img_scale=0.5):

    dir_img = 'data/train/' # folder of training images
    dir_mask = 'data/train_masks/' # folder of ground-truth masks
    dir_checkpoint = 'checkpoints/' # folder where trained checkpoints are saved

    ids = get_ids(dir_img) # filename stems (extension stripped) serve as image ids

    # yields the tuple list [(id1,0),(id1,1),(id2,0),(id2,1),...,(idn,0),(idn,1)]
    # When the generators are reset later, the trailing 0/1 becomes the pos argument of get_square in utils.py:
    # pos=0 takes the left square, pos=1 the right square.
    # This doubles the number of training images.
    ids = split_ids(ids) 

    iddataset = split_train_val(ids, val_percent) # split the data into a training set and a validation set

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train']) # size of the training set

    optimizer = optim.SGD(net.parameters(), # define the optimizer
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    criterion = nn.BCELoss() # binary cross-entropy loss

    for epoch in range(epochs): # start training
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train() # switch to training mode

        # reset the generators:
        # input images from dir_img and ground-truth masks from dir_mask receive the same processing
        # (downscale, crop, transpose, normalize) and are zipped into pairs (imgs_normalized, masks)
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0

        for i, b in enumerate(batch(train, batch_size)):
            imgs = np.array([i[0] for i in b]).astype(np.float32) # the input images of this batch
            true_masks = np.array([i[1] for i in b]) # the corresponding ground-truth masks

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs) # forward pass; masks_pred is a grayscale probability map
            masks_probs_flat = masks_pred.view(-1) # flatten the predictions

            true_masks_flat = true_masks.view(-1) 

            loss = criterion(masks_probs_flat, true_masks_flat) # loss between prediction and ground truth
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / i)) # average loss of the epoch

        if 1:
            val_dice = eval_net(net, val, gpu)
            print('Validation Dice Coeff: {}'.format(val_dice))

        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))



def get_args():
    parser = OptionParser()
    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int', # number of training epochs
                      help='number of epochs')
    parser.add_option('-b', '--batch-size', dest='batchsize', default=10, # batch size
                      type='int', help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1, # learning rate
                      type='float', help='learning rate')
    parser.add_option('-g', '--gpu', action='store_true', dest='gpu', # whether to train on the GPU; off by default
                      default=False, help='use cuda')
    parser.add_option('-c', '--load', dest='load', # load a previously trained model and continue training from it
                      default=False, help='load file model')
    parser.add_option('-s', '--scale', dest='scale', type='float', # downscaling factor used to resize the images
                      default=0.5, help='downscaling factor of the images') 

    (options, args) = parser.parse_args()
    return options

if __name__ == '__main__':
    args = get_args() # all the parsed option values

    net = UNet(n_channels=3, n_classes=1)

    if args.load: # optionally load a previously trained model
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu: # use the GPU if requested
        net.cuda()
        # cudnn.benchmark = True # faster convolutions, but more memory

    try: # start training
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  img_scale=args.scale)
    except KeyboardInterrupt: # on Ctrl+C, save the current weights to INTERRUPTED.pth
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
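
Since Ctrl+C saves the current weights to INTERRUPTED.pth, training can later be resumed from them via the --load option, e.g.:

python train.py -g -c INTERRUPTED.pth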

 

