本文主要貢獻代碼模塊(文末),在本文中對resnet進行了復現,是一份原始版本模塊,里面集成了權重文件pth的載入模塊(如函數:init_weights(self, pretrained=None)),layers的凍結模塊(如函數:_freeze_stages(self)),更是將其改寫成可讀性高的代碼,若你需要執行該模塊,可直接將其代碼模塊粘貼成.py文件即可。而理論模塊,並非本文重點,因此借鑒博客:https://zhuanlan.zhihu.com/p/42706477 ,我將不再說明:
注:本人也意在改寫更多backbones模塊,后續將會放入該github中,可供代碼下載:https://github.com/tangjunjun966/backbones
本文機構:1.基本原理;2. Resnet代碼復現;3.代碼運行結果展示
1.基本原理
ResNet的作者何凱明也因此摘得CVPR2016最佳論文獎,當然何博士的成就遠不止於此,感興趣的可以去搜一下他后來的輝煌戰績。那么ResNet為什么會有如此優異的表現呢?其實ResNet是解決了深度CNN模型難訓練的問題,從圖2中可以看到14年的VGG才19層,而15年的ResNet多達152層,這在網絡深度完全不是一個量級上,所以如果是第一眼看這個圖的話,肯定會覺得ResNet是靠深度取勝。事實當然是這樣,但是ResNet還有架構上的trick,這才使得網絡的深度發揮出作用,這個trick就是殘差學習(Residual learning)。下面詳細講述ResNet的理論及實現。
深度網絡的退化問題
從經驗來看,網絡的深度對模型的性能至關重要,當增加網絡層數后,網絡可以進行更加復雜的特征模式的提取,所以當模型更深時理論上可以取得更好的結果,從圖2中也可以看出網絡越深而效果越好的一個實踐證據。但是更深的網絡其性能一定會更好嗎?實驗發現深度網絡出現了退化問題(Degradation problem):網絡深度增加時,網絡准確度出現飽和,甚至出現下降。這個現象可以在圖3中直觀看出來:56層的網絡比20層網絡效果還要差。這不會是過擬合問題,因為56層網絡的訓練誤差同樣高。我們知道深層網絡存在着梯度消失或者爆炸的問題,這使得深度學習模型很難訓練。但是現在已經存在一些技術手段如BatchNorm來緩解這個問題。因此,出現深度網絡的退化問題是非常令人詫異的。
圖3 20層與56層網絡在CIFAR-10上的誤差
殘差學習
深度網絡的退化問題至少說明深度網絡不容易訓練。但是我們考慮這樣一個事實:現在你有一個淺層網絡,你想通過向上堆積新層來建立深層網絡,一個極端情況是這些增加的層什么也不學習,僅僅復制淺層網絡的特征,即這樣新層是恆等映射(Identity mapping)。在這種情況下,深層網絡應該至少和淺層網絡性能一樣,也不應該出現退化現象。好吧,你不得不承認肯定是目前的訓練方法有問題,才使得深層網絡很難去找到一個好的參數。
這個有趣的假設讓何博士靈感爆發,他提出了殘差學習來解決退化問題。對於一個堆積層結構(幾層堆積而成)當輸入為 時其學習到的特征記為
,現在我們希望其可以學習到殘差
,這樣其實原始的學習特征是
。之所以這樣是因為殘差學習相比原始特征直接學習更容易。當殘差為0時,此時堆積層僅僅做了恆等映射,至少網絡性能不會下降,實際上殘差不會為0,這也會使得堆積層在輸入特征基礎上學習到新的特征,從而擁有更好的性能。殘差學習的結構如圖4所示。這有點類似與電路中的“短路”,所以是一種短路連接(shortcut connection)。
圖4 殘差學習單元
為什么殘差學習相對更容易,從直觀上看殘差學習需要學習的內容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數學的角度來分析這個問題,首先殘差單元可以表示為:
其中 和
分別表示的是第
個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。
是殘差函數,表示學習到的殘差,而
表示恆等映射,
是ReLU激活函數。基於上式,我們求得從淺層
到深層
的學習特征為:
利用鏈式規則,可以求得反向過程的梯度:
式子的第一個因子 表示的損失函數到達
的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那么巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導並不是嚴格的證明。
ResNet的網絡結構
ResNet網絡是參考了VGG19網絡,在其基礎上進行了修改,並通過短路機制加入了殘差單元,如圖5所示。變化主要體現在ResNet直接使用stride=2的卷積做下采樣,並且用global average pool層替換了全連接層。ResNet的一個重要設計原則是:當feature map大小降低一半時,feature map的數量增加一倍,這保持了網絡層的復雜度。從圖5中可以看到,ResNet相比普通網絡每兩層間增加了短路機制,這就形成了殘差學習,其中虛線表示feature map數量發生了改變。圖5展示的34-layer的ResNet,還可以構建更深的網絡如表1所示。從表中可以看到,對於18-layer和34-layer的ResNet,其進行的兩層間的殘差學習,當網絡更深時,其進行的是三層間的殘差學習,三層卷積核分別是1x1,3x3和1x1,一個值得注意的是隱含層的feature map數量是比較小的,並且是輸出feature map數量的1/4。
圖5 ResNet網絡結構圖
表1 不同深度的ResNet
下面我們再分析一下殘差單元,ResNet使用兩種殘差單元,如圖6所示。左圖對應的是淺層網絡,而右圖對應的是深層網絡。對於短路連接,當輸入和輸出維度一致時,可以直接將輸入加到輸出上。但是當維度不一致時(對應的是維度增加一倍),這就不能直接相加。有兩種策略:(1)采用zero-padding增加維度,此時一般要先做一個downsamp,可以采用strde=2的pooling,這樣不會增加參數;(2)采用新的映射(projection shortcut),一般采用1x1的卷積,這樣會增加參數,也會增加計算量。短路連接除了直接使用恆等映射,當然都可以采用projection shortcut。
圖6 不同的殘差單元
作者對比18-layer和34-layer的網絡效果,如圖7所示。可以看到普通的網絡出現退化現象,但是ResNet很好的解決了退化問題。
圖7 18-layer和34-layer的網絡效果
最后展示一下ResNet網絡與其他網絡在ImageNet上的對比結果,如表2所示。可以看到ResNet-152其誤差降到了4.49%,當采用集成模型后,誤差可以降到3.57%。
表2 ResNet與其他網絡的對比結果
說一點關於殘差單元題外話,上面我們說到了短路連接的幾種處理方式,其實作者在文獻[2]中又對不同的殘差單元做了細致的分析與實驗,這里我們直接拋出最優的殘差結構,如圖8所示。改進前后一個明顯的變化是采用pre-activation,BN和ReLU都提前了。而且作者推薦短路連接采用恆等變換,這樣保證短路連接不會有阻礙。感興趣的可以去讀讀這篇文章。
2.Resnet代碼復現
""" @author: tangjun @contact: 511026664@qq.com @time: 2020/12/7 22:48 @desc: 殘差ackbone """ import torch.nn as nn import torch from collections import OrderedDict def Conv(in_planes, out_planes, **kwargs): "3x3 convolution with padding" padding = kwargs.get('padding', 1) bias = kwargs.get('bias', False) stride = kwargs.get('stride', 1) kernel_size = kwargs.get('kernel_size', 3) out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) return out class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = Conv(inplanes, planes, stride=stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = Conv(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Resnet(nn.Module): arch_settings = { 18: (BasicBlock, (2, 2, 2, 2)), 34: (BasicBlock, (3, 4, 6, 3)), 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)) } def __init__(self, depth, in_channels=None, pretrained=None, frozen_stages=-1 # num_classes=None ): self.inplanes = 64 super(Resnet, self).__init__() self.inchannels = in_channels if in_channels is not None else 3 # 輸入通道 # self.num_classes=num_classes self.block, layers = self.arch_settings[depth] self.frozen_stages=frozen_stages self.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1) self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2) self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2) self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2) # self.avgpool = nn.AvgPool2d(7) # self.fc = nn.Linear(512 * self.block.expansion, self.num_classes) self._freeze_stages() # 凍結函數 def _freeze_stages(self): if self.frozen_stages >= 0: self.norm1.eval() for m in [self.conv1, self.norm1]: for param in m.parameters(): param.requires_grad = False for i in range(1, self.frozen_stages + 1): m = getattr(self, 'layer{}'.format(i)) m.eval() for param in m.parameters(): param.requires_grad = False def init_weights(self, pretrained=None): if isinstance(pretrained, str): self.load_checkpoint(pretrained) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu') if hasattr(m, 'bias') and m.bias is not None: # m包含該屬性且m.bias非None # hasattr(對象,屬性)表示對象是否包含該屬性 nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def load_checkpoint(self, pretrained): checkpoint = torch.load(pretrained) if isinstance(checkpoint, OrderedDict): state_dict = checkpoint elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] if list(state_dict.keys())[0].startswith('module.'): state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} unexpected_keys = [] # 保存checkpoint不在module中的key model_state = self.state_dict() # 模型變量 for name, param in state_dict.items(): # 循環遍歷pretrained的權重 if name not in model_state: unexpected_keys.append(name) continue if isinstance(param, torch.nn.Parameter): # backwards compatibility for serialized parameters param = param.data try: model_state[name].copy_(param) # 試圖賦值給模型 except Exception: raise RuntimeError( 'While copying the parameter named {}, ' 'whose dimensions in the model are {} not equal ' 'whose dimensions in the checkpoint are {}.'.format( name, model_state[name].size(), param.size())) missing_keys = set(model_state.keys()) - set(state_dict.keys()) print('missing_keys:',missing_keys) def _make_layer(self, block, planes, num_blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, num_blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def forward(self, x): outs = [] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) outs.append(x) x = self.layer2(x) outs.append(x) x = self.layer3(x) outs.append(x) x = self.layer4(x) outs.append(x) # x = self.avgpool(x) # x = x.view(x.size(0), -1) # x = self.fc(x) return tuple(outs) if __name__ == '__main__': x = torch.ones((2, 3, 215, 215)) model = Resnet(depth=50) model.init_weights(pretrained='./resnet50.pth') out = model(x) print(out)
3.代碼運行結果展示