【深度學習】基於Pytorch的ResNet實現


1. ResNet理論

論文:https://arxiv.org/pdf/1512.03385.pdf

殘差學習基本單元:

img

在ImageNet上的結果:

效果會隨着模型層數的提升而下降,當更深的網絡能夠開始收斂時,就會出現降級問題:隨着網絡深度的增加,准確度變得飽和(這可能不足為奇),然后迅速降級。

ResNet模型:

2. pytorch實現

2.1 基礎卷積

conv3$\times\(3 和conv1\)\times$1 基礎模塊

def conv3x3(in_channel, out_channel, stride=1, groups=1, dilation=1):
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_channel, out_channel, stride=1):
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)

參數解釋:

in_channel: 輸入的通道數目

out_channel:輸出的通道數目

stride, padding: 步長和補0

dilation: 空洞卷積中的參數

groups: 從輸入通道到輸出通道的阻塞連接數

feature size 計算:
output = (intput - filter_size + 2 x padding) / stride + 1

空洞卷積實際卷積核大小:

K = K + (K-1)x(R-1)
K 是原始卷積核大小
R 是空洞卷積參數的空洞率(普通卷積為1)

2.2 模塊

- resnet34
	- _resnet
		- ResNet
			- _make_layer
				- block 
					- Bottleneck
					- BasicBlock			

Bottlenect

class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

BasicBlock

class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

2.3 使用ResNet模塊進行遷移學習

import torchvision.models as models
import torch.nn as nn

class RES18(nn.Module):
    def __init__(self):
        super(RES18, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet18(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES34(nn.Module):
    def __init__(self):
        super(RES34, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet34(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES50(nn.Module):
    def __init__(self):
        super(RES50, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet50(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES101(nn.Module):
    def __init__(self):
        super(RES101, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet101(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES152(nn.Module):
    def __init__(self):
        super(RES152, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet152(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

使用模塊直接生成一個類即可,比如訓練的時候:

cnn = RES101()
cnn.train() # 改為訓練模式
prediction = cnn(img) #進行預測

目前先寫這么多,看過了源碼以后感覺寫的很好,不僅僅有論文中最基礎的部分,還有一些額外的功能,模塊的組織也很整齊。

平時使用一般都進行遷移學習,使用的話可以把上述幾個類中pretrained=False參數改為True.

實戰篇:以上遷移學習代碼來自我的一個小項目,驗證碼識別,地址:https://github.com/pprp/captcha_identify.torch


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM