殘差網絡ResNet(超詳細代碼解析) :你必須要知道backbone模塊成員之一


 

         本文主要貢獻代碼模塊(文末),在本文中對resnet進行了復現,是一份原始版本模塊,里面集成了權重文件pth的載入模塊(如函數:init_weights(self, pretrained=None)),layers的凍結模塊(如函數:_freeze_stages(self)),更是將其改寫成可讀性高的代碼,若你需要執行該模塊,可直接將其代碼模塊粘貼成.py文件即可。而理論模塊,並非本文重點,因此借鑒博客:https://zhuanlan.zhihu.com/p/42706477 ,我將不再說明:

 

注:本人也意在改寫更多backbones模塊,后續將會放入該github中,可供代碼下載:https://github.com/tangjunjun966/backbones

 

本文機構:1.基本原理;2. Resnet代碼復現;3.代碼運行結果展示

 

 1.基本原理

ResNet的作者何凱明也因此摘得CVPR2016最佳論文獎,當然何博士的成就遠不止於此,感興趣的可以去搜一下他后來的輝煌戰績。那么ResNet為什么會有如此優異的表現呢?其實ResNet是解決了深度CNN模型難訓練的問題,從圖2中可以看到14年的VGG才19層,而15年的ResNet多達152層,這在網絡深度完全不是一個量級上,所以如果是第一眼看這個圖的話,肯定會覺得ResNet是靠深度取勝。事實當然是這樣,但是ResNet還有架構上的trick,這才使得網絡的深度發揮出作用,這個trick就是殘差學習(Residual learning)。下面詳細講述ResNet的理論及實現。

深度網絡的退化問題

從經驗來看,網絡的深度對模型的性能至關重要,當增加網絡層數后,網絡可以進行更加復雜的特征模式的提取,所以當模型更深時理論上可以取得更好的結果,從圖2中也可以看出網絡越深而效果越好的一個實踐證據。但是更深的網絡其性能一定會更好嗎?實驗發現深度網絡出現了退化問題(Degradation problem):網絡深度增加時,網絡准確度出現飽和,甚至出現下降。這個現象可以在圖3中直觀看出來:56層的網絡比20層網絡效果還要差。這不會是過擬合問題,因為56層網絡的訓練誤差同樣高。我們知道深層網絡存在着梯度消失或者爆炸的問題,這使得深度學習模型很難訓練。但是現在已經存在一些技術手段如BatchNorm來緩解這個問題。因此,出現深度網絡的退化問題是非常令人詫異的。

圖3 20層與56層網絡在CIFAR-10上的誤差

殘差學習

深度網絡的退化問題至少說明深度網絡不容易訓練。但是我們考慮這樣一個事實:現在你有一個淺層網絡,你想通過向上堆積新層來建立深層網絡,一個極端情況是這些增加的層什么也不學習,僅僅復制淺層網絡的特征,即這樣新層是恆等映射(Identity mapping)。在這種情況下,深層網絡應該至少和淺層網絡性能一樣,也不應該出現退化現象。好吧,你不得不承認肯定是目前的訓練方法有問題,才使得深層網絡很難去找到一個好的參數。

這個有趣的假設讓何博士靈感爆發,他提出了殘差學習來解決退化問題。對於一個堆積層結構(幾層堆積而成)當輸入為 [公式] 時其學習到的特征記為 [公式] ,現在我們希望其可以學習到殘差 [公式] ,這樣其實原始的學習特征是 [公式] 。之所以這樣是因為殘差學習相比原始特征直接學習更容易。當殘差為0時,此時堆積層僅僅做了恆等映射,至少網絡性能不會下降,實際上殘差不會為0,這也會使得堆積層在輸入特征基礎上學習到新的特征,從而擁有更好的性能。殘差學習的結構如圖4所示。這有點類似與電路中的“短路”,所以是一種短路連接(shortcut connection)。

圖4 殘差學習單元

為什么殘差學習相對更容易,從直觀上看殘差學習需要學習的內容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數學的角度來分析這個問題,首先殘差單元可以表示為:

[公式]

其中 [公式] 和 [公式] 分別表示的是第 [公式] 個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。 [公式] 是殘差函數,表示學習到的殘差,而 [公式] 表示恆等映射, [公式] 是ReLU激活函數。基於上式,我們求得從淺層 [公式] 到深層 [公式] 的學習特征為:

[公式]

利用鏈式規則,可以求得反向過程的梯度:

[公式]

式子的第一個因子 [公式] 表示的損失函數到達 [公式] 的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那么巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導並不是嚴格的證明。

ResNet的網絡結構

ResNet網絡是參考了VGG19網絡,在其基礎上進行了修改,並通過短路機制加入了殘差單元,如圖5所示。變化主要體現在ResNet直接使用stride=2的卷積做下采樣,並且用global average pool層替換了全連接層。ResNet的一個重要設計原則是:當feature map大小降低一半時,feature map的數量增加一倍,這保持了網絡層的復雜度。從圖5中可以看到,ResNet相比普通網絡每兩層間增加了短路機制,這就形成了殘差學習,其中虛線表示feature map數量發生了改變。圖5展示的34-layer的ResNet,還可以構建更深的網絡如表1所示。從表中可以看到,對於18-layer和34-layer的ResNet,其進行的兩層間的殘差學習,當網絡更深時,其進行的是三層間的殘差學習,三層卷積核分別是1x1,3x3和1x1,一個值得注意的是隱含層的feature map數量是比較小的,並且是輸出feature map數量的1/4。

圖5 ResNet網絡結構圖表1 不同深度的ResNet

下面我們再分析一下殘差單元,ResNet使用兩種殘差單元,如圖6所示。左圖對應的是淺層網絡,而右圖對應的是深層網絡。對於短路連接,當輸入和輸出維度一致時,可以直接將輸入加到輸出上。但是當維度不一致時(對應的是維度增加一倍),這就不能直接相加。有兩種策略:(1)采用zero-padding增加維度,此時一般要先做一個downsamp,可以采用strde=2的pooling,這樣不會增加參數;(2)采用新的映射(projection shortcut),一般采用1x1的卷積,這樣會增加參數,也會增加計算量。短路連接除了直接使用恆等映射,當然都可以采用projection shortcut。

圖6 不同的殘差單元

作者對比18-layer和34-layer的網絡效果,如圖7所示。可以看到普通的網絡出現退化現象,但是ResNet很好的解決了退化問題。

圖7 18-layer和34-layer的網絡效果

最后展示一下ResNet網絡與其他網絡在ImageNet上的對比結果,如表2所示。可以看到ResNet-152其誤差降到了4.49%,當采用集成模型后,誤差可以降到3.57%。

表2 ResNet與其他網絡的對比結果

說一點關於殘差單元題外話,上面我們說到了短路連接的幾種處理方式,其實作者在文獻[2]中又對不同的殘差單元做了細致的分析與實驗,這里我們直接拋出最優的殘差結構,如圖8所示。改進前后一個明顯的變化是采用pre-activation,BN和ReLU都提前了。而且作者推薦短路連接采用恆等變換,這樣保證短路連接不會有阻礙。感興趣的可以去讀讀這篇文章。

 

 

 

 

 2.Resnet代碼復現

"""
@author: tangjun
@contact: 511026664@qq.com
@time: 2020/12/7 22:48
@desc: 殘差ackbone
"""

import torch.nn as nn
import torch
from collections import OrderedDict


def Conv(in_planes, out_planes, **kwargs):
    "3x3 convolution with padding"
    padding = kwargs.get('padding', 1)
    bias = kwargs.get('bias', False)
    stride = kwargs.get('stride', 1)
    kernel_size = kwargs.get('kernel_size', 3)
    out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    return out


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = Conv(inplanes, planes, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = Conv(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Resnet(nn.Module):
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth,
                 in_channels=None,
                 pretrained=None,

                frozen_stages=-1

                 # num_classes=None
                 ):
        self.inplanes = 64
        super(Resnet, self).__init__()

        self.inchannels = in_channels if in_channels is not None else 3  # 輸入通道
        # self.num_classes=num_classes
        self.block, layers = self.arch_settings[depth]
        self.frozen_stages=frozen_stages
        self.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2)

        # self.avgpool = nn.AvgPool2d(7)
        # self.fc = nn.Linear(512 * self.block.expansion, self.num_classes)
        self._freeze_stages()  # 凍結函數
    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False


    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            self.load_checkpoint(pretrained)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                    if hasattr(m, 'bias') and m.bias is not None: # m包含該屬性且m.bias非None # hasattr(對象,屬性)表示對象是否包含該屬性
                        nn.init.constant_(m.bias, 0)

                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def load_checkpoint(self, pretrained):

        checkpoint = torch.load(pretrained)
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']

        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

        unexpected_keys = []  # 保存checkpoint不在module中的key
        model_state = self.state_dict()  # 模型變量

        for name, param in state_dict.items():  # 循環遍歷pretrained的權重
            if name not in model_state:
                unexpected_keys.append(name)
                continue
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                model_state[name].copy_(param)  # 試圖賦值給模型
            except Exception:
                raise RuntimeError(
                    'While copying the parameter named {}, '
                    'whose dimensions in the model are {} not equal '
                    'whose dimensions in the checkpoint are {}.'.format(
                        name, model_state[name].size(), param.size()))
        missing_keys = set(model_state.keys()) - set(state_dict.keys())
        print('missing_keys:',missing_keys)
    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        outs = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        outs.append(x)
        x = self.layer2(x)
        outs.append(x)
        x = self.layer3(x)
        outs.append(x)
        x = self.layer4(x)
        outs.append(x)

        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        # x = self.fc(x)

        return tuple(outs)


if __name__ == '__main__':
    x = torch.ones((2, 3, 215, 215))
    model = Resnet(depth=50)

    model.init_weights(pretrained='./resnet50.pth')


    out = model(x)

    print(out)

 

 3.代碼運行結果展示

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM