說明
模型裁剪可分為兩種,一種是稀疏化裁剪,裁剪的粒度為值級別,一種是結構化裁剪,最常用的是通道裁剪。通道裁剪是減少輸出特征圖的通道數,對應的權值是卷積核的個數。
問題
通常模型裁剪的三個步驟是:1. 判斷網絡中不重要的通道 2. 刪減掉不重要的通道(一般不會立即刪,加mask等到評測時才開始刪) 3. 將模型導出,然后進行finetue恢復精度。
步驟1,2涉及到非常多的標准和方法,這里不去深究。但是到第3步的時候,怎么導出網絡,看似很簡單的問題,但是如果碰到resnet這種,是要花費時間研究細節的,而且目前還沒有人專門講這塊(實際上是個工程實現問題),下面來詳細說說。
以MobileNet為代表的模型
先考慮以mobilenet為代表的模型,mobilenet中包含了一系列塊,每塊中包含了深度可分離卷積和點卷積,然后整個模型就是一系列block塊的堆疊,在目前很多模型中都具有代表性。
首先我們只考慮了模型的11卷積,因為11卷積是最耗算力的,而33卷積的裁剪實際上沒有必要,意味可分離意味着將輸入特征圖的信息丟掉,與其丟掉,那不如在一開始就不去計算要丟掉的那部分,而不計算的那部分正是由前一層的11點卷積得到的,也就是說改變前一層的輸出通道,就等同於對當前的可分離卷積的裁剪。
然后問題就只剩下11卷積核的裁剪了,那么需要在模型初始化時設置不同的profile,來實現不同結構的模型裁剪模型,這里代碼中的例子是將第一個block中11卷積核的128通道裁剪為64通道,其他通道可依次次類推。
class MobileNet(nn.Module):
def __init__(self, n_class, profile='normal', channels=None):
self.channels = [32, 64, 104, 128, 248, 224, 456, 296, 456, 224, 104, 104, 208, 208]
if channels:
self.channsels = channels
super(MobileNet, self).__init__()
# original
if profile == 'normal':
in_planes = 32
cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), (1024,1)]
# 0.5 AMC
elif profile == '0.5flops':
in_planes = self.channels[0]
strides = [1, 2, 1, 2, 1, 2, 1,1,1,1,1, 2,1]
cfg = list(zip(self.channels[1:], strides))
else:
raise NotImplementedError
而在make_layers部分,需要判斷當前stride, 有三次stride,每次縮放一倍,默認stride都是1,當然也可以把stride全列舉出來,就不用判斷了。
def _make_layers(self, in_planes, cfg, layer):
layers = []
for x in cfg:
out_planes = x if isinstance(x, int) else x[0]
stride = 1 if isinstance(x, int) else x[1]
layers.append(layer(in_planes, out_planes, stride))
in_planes = out_planes
return nn.Sequential(*layers)
以Resnet50為代表的模型
前面解決了mobileNet的問題,其實也是一個基本網絡架構下的裁剪問題,但是目前的網絡往往具有復雜的連接,比如像resnet這樣,具有殘差結構的單元塊,這意味着殘差部分需要單獨處理。
在我壓縮完成得到壓縮配置之后,先寫了簡單版本的resnet_pruning版本,這是最朴素的思想:
def ResNet50_Pruning(**kwarg):
model = ResNet(Bottleneck_Pruning, [3,4,6,3], **kwarg)
p = 0
actions = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
for i, m in enumerate(model.modules()):
if type(m) in (nn.Conv2d, nn.Linear):
if type(m) == nn.Conv2d and m.groups == m.in_channels: # depth-wise conv, buffer
continue
else:
if type(m) is nn.Linear:
m.in_features = actions[p]
else:
m.in_channels = actions[p]
m.out_channels = actions[p+1]
p += 1
return model
將每一層對應的actions都找到,然后令其channel都做出改變,這無疑是最直觀的寫法,但是由於CONV之后往往帶着BN層,當改完CONV之后,你發現BN還是原有的值,這就會使得維度不匹配。當然我有參考了Nvi-Lab的寫法,可以先建模型,然后獲取壓縮的action的裁剪通道,然后重建一個new_conv代替原有的conv,這樣寫也行,也是一種思路,不過我覺得這樣不優雅,而且容易漏東西。
然后我使用了另外一個思路,在建立模型的時候就建立一個裁剪之后的模型,但是由於resnet50有多個blocks,然后每個block中是一個瓶頸結構,這就需要你定位到哪一個block,以及該block中哪一個卷積,這樣就存在的一個缺陷就是需要全局的index來標記當前層是第幾層。具體的實現如下:
cfg = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
shortcut = [1, 10, 22, 40]
block_nums = [3, 4, 6, 3]
def Sample(x, num):
np.random.seed(2019)
batch_size, channel_num, height, width = x.data.size()
channel_index = np.random.choice(channel_num, num)
x = x[:, channel_index, :, :]
return x
class Bottleneck(nn.Module):
def __init__(self, in_planes, planes, stride=1, offset=1):
super(Bottleneck, self).__init__()
# pw
self.conv1 = nn.Conv2d(cfg[offset], cfg[offset+1], kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(cfg[offset+1])
# dw
self.conv2 = nn.Conv2d(cfg[offset+1], cfg[offset+2], kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(cfg[offset+2])
# pw
self.conv3 = nn.Conv2d(cfg[offset+2], cfg[offset+3], kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(cfg[offset+3])
self.shortcut = nn.Sequential()
if offset in shortcut:
p = shortcut.index(offset)
self.shortcut = nn.Sequential(
nn.Conv2d(cfg[offset], cfg[offset+3], kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(cfg[offset+3])
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
s = self.shortcut(x)
if s.data.size() != out.data.size():
s = Sample(s, out.data.size(1))
out += s
out = F.relu(out)
return out
可以看到,對於殘差那一分支,采用的是sample采樣的方法來使得通道數與瓶頸結構相同,之所以不對瓶頸結構中的卷積結果進行采樣,是由於這樣可以盡可能多地保留輸入特征的信息。而瓶頸結構中多了offset參數用以標記當前的卷積的索引。