[Paper Notes] Learning Efficient Convolutional Networks through Network Slimming



Introduction

This is the first model-compression paper I have read, and probably a fairly well-known one. I had been interested in model compression for a long time, so I took some time to read it and implemented the code myself; it turned out to be quite approachable. For the compression problem, the paper proposes a pruning method that targets the batch normalization (BN) layers: during training an L1 penalty pushes the BN scaling factors (the γ weights) toward zero, and each factor is then used as an importance score for its channel. Channels whose scores fall below a threshold are filtered out, their neurons no longer take part in the connections to the next layer, and pruning layer by layer in this way yields the compression. A sketch of the sparsity step follows.
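
For reference, here is a minimal sketch of that sparsity-inducing training step. The function name and the value of λ are my own choices; the idea of adding the L1 subgradient on γ after the backward pass follows the paper:

import torch
import torch.nn as nn

def add_bn_l1_subgradient(model, lam=1e-4):
    # Add the subgradient of lam * sum(|gamma|) to every BN scaling factor's
    # gradient; call this after loss.backward() and before optimizer.step().
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(lam * torch.sign(m.weight.data))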

As for my own take: the attention mechanisms in common use today could, I think, also be used to score channels, which seems worth exploring, though it would certainly be task-specific. I plan to run some experiments later using attention for model pruning.
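
Purely as a sketch of that idea (hypothetical and untested; the class name and reduction factor are my own), an SE-style block could produce per-channel gates whose average over a dataset would serve as a channel score:

import torch
import torch.nn as nn

class SEChannelScore(nn.Module):
    # Squeeze-and-excitation style gate; averaging its outputs over a dataset
    # would give one importance score per channel.
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid())

    def forward(self, x):
        return self.fc(x.mean(dim=(2, 3)))  # global average pool -> (N, C) gates in (0, 1)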

Method

The paper's method proceeds as follows:

  1. Given the fraction of channels to keep, collect the scaling factors of all BN layers and record which weights lie above the resulting threshold (see the toy sketch after this list).
  2. Prune the BN layers first, i.e. discard the parameters whose weights fall below that threshold.
  3. Prune the convolutional layers: since a convolution is usually followed by a BN layer, the BN layers before and after it determine the conv weight size, so the conv weight tensor is sliced to drop the channels removed from the adjacent BN layers.
  4. Prune the FC layer accordingly.
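
A toy illustration of step 1 with made-up numbers: pool the |γ| values of all BN layers, sort them, and take the value at the prune-ratio quantile as a single global threshold (the second script below does exactly this):

import torch

gammas = torch.tensor([0.9, 0.05, 0.4, 0.01, 0.7, 0.3])  # all |gamma| pooled across BN layers
prune_ratio = 0.5
thre = torch.sort(gammas)[0][int(gammas.numel() * prune_ratio)]  # value at the 50% quantile: 0.4
mask = gammas.gt(thre)  # keeps the 0.9 and 0.7 channels; the rest are pruned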

This is hard to state precisely in words, but it all becomes clear once you read the code.

Code

I implemented it myself, as follows.

import torch
import torch.nn as nn
from torchsummary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3,16,kernel_size = 3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16,32,kernel_size = 3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,64,kernel_size = 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,128,kernel_size = 3),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.maxpool = nn.MaxPool2d(216)  # a 224 input shrinks to 216 after four 3x3 convs, so this acts as global pooling
        self.fc = nn.Linear(128,3)

    def forward(self,x):
        x = self.convnet(x)
        x = self.maxpool(x)
        x = x.view(-1,x.size(1))
        return self.fc(x)

if __name__ == "__main__":
    net = Net()
    net_new = Net()
    idxs = []
    idxs.append(range(3))  # the 3 input image channels are never pruned
    for module in net.modules():
        if type(module) is nn.BatchNorm2d:
            weight = module.weight.data.abs()  # |gamma| is the channel importance score
            n = weight.size(0)
            y, idx = torch.sort(weight, descending=True)  # largest scores first
            n = int(0.8 * n)  # keep the top 80% of channels in each layer
            idxs.append(torch.sort(idx[:n])[0])  # keep the selected indices in channel order
    i = 1  # idxs[0] is the input-channel mask; BN/conv masks start at index 1
    for module in net_new.modules():
        if type(module) is nn.Conv2d:
            weight = module.weight.data.clone()
            weight = weight[idxs[i],:,:,:]      # prune output channels
            weight = weight[:,idxs[i-1],:,:]    # prune input channels
            module.bias.data = module.bias.data[idxs[i]]
            module.weight.data = weight
        elif type(module) is nn.BatchNorm2d:
            # slice gamma, beta and the running statistics down to the kept channels
            weight = module.weight.data.clone()
            bias = module.bias.data.clone()
            running_mean = module.running_mean.clone()
            running_var = module.running_var.clone()
            
            weight = weight[idxs[i]]
            bias = bias[idxs[i]]
            running_mean = running_mean[idxs[i]]
            running_var = running_var[idxs[i]]

            module.weight.data = weight
            module.bias.data = bias
            module.running_var.data = running_var
            module.running_mean.data = running_mean
            i += 1
        elif type(module) is nn.Linear:
            # the FC input features correspond to the channels kept by the last BN layer
            module.weight.data = module.weight.data[:,idxs[-1]]
            
    summary(net_new,(3,224,224),device = "cpu")
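    # Sanity check (my own addition, not in the original post): compare the
    # parameter counts of the unpruned and pruned networks.
    n_old = sum(p.nelement() for p in net.parameters())
    n_new = sum(p.nelement() for p in net_new.parameters())
    print("params: {} -> {} ({:.1f}% kept)".format(n_old, n_new, 100.0 * n_new / n_old))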

'''
Below is the pruning script for VGG; the paper also describes slimming for
other architectures. Unlike my toy example above, which keeps a fixed ratio
per layer, this script applies one global threshold across all BN weights.
'''
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from models import *  # the configurable vgg() from the official network-slimming repo


# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=19,
                    help='depth of the vgg')
parser.add_argument('--percent', type=float, default=0.5,
                    help='fraction of channels to prune (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if args.save and not os.path.exists(args.save):
    os.makedirs(args.save)

model = vgg(dataset=args.dataset, depth=args.depth)  # the configurable vgg from models, not torchvision's vgg19
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

print(model)
total = 0
for m in model.modules():  # iterate over every module of the VGG
    if isinstance(m, nn.BatchNorm2d):  # found a BN layer
        total += m.weight.data.shape[0]  # total = number of channels across all BN layers

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size  # copy every BN layer's |gamma| into one flat vector

y, i = torch.sort(bn)  # sort all scaling factors in ascending order
thre_index = int(total * args.percent)  # number of channels to prune
thre = y[thre_index]  # global threshold: channels with |gamma| <= thre are pruned

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().cuda()  # 1 where |gamma| > thre, else 0
        pruned = pruned + mask.shape[0] - torch.sum(mask)  # running count of pruned channels
        m.weight.data.mul_(mask)  # zero out the pruned scaling factors
        m.bias.data.mul_(mask)  # and the corresponding biases
        cfg.append(int(torch.sum(mask)))  # number of channels kept in this BN layer
        cfg_mask.append(mask.clone())  # this layer's keep mask
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total  # fraction of channels pruned

print('Pre-processing Successful!')

# quick test after the pre-processing prune (the pruned BN scales are merely set to zero)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True) API
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

# Make real prune
print(cfg)
newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()
# torch.nelement() returns the number of elements in a tensor
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
# e.g. a weight of shape (20, 3, 3, 3) contributes 20 * 27 = 540 elements,
# so this sums up the total parameter count of the pruned model
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

layer_id_in_cfg = 0  # index into cfg_mask
start_mask = torch.ones(3)  # the 3 input image channels are all kept
end_mask = cfg_mask[layer_id_in_cfg]  # keep mask of the current BN layer
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        # np.argwhere returns the indices of all nonzero mask entries, i.e. the kept channels
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))  # squeeze to a 1-D index vector
        if idx1.size == 1:  # a single kept channel must remain a length-1 array
            idx1 = np.resize(idx1,(1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()  # copy only the kept channels into the new model
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1  # move on to the next layer
        start_mask = end_mask.clone()  # this layer's output mask is the next layer's input mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):  # prune the convolution
        # each convolution is followed by a BN layer, so both of its masks are known
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()  # slice the input channels first
        w1 = w1[idx1.tolist(), :, :, :].clone()  # then the output channels: the final pruned weight
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
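
The paper's full pipeline is sparsity training, then pruning, then fine-tuning, and the pruned model usually needs a few epochs of fine-tuning to recover accuracy. A minimal sketch, appended to the script above (the optimizer settings, epoch count, and hard-coded CIFAR10 loader are my own assumptions, not taken from the official repo):

import torch.optim as optim

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data.cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.4914, 0.4822, 0.4465),
                                              (0.2023, 0.1994, 0.2010))])),
    batch_size=64, shuffle=True)

optimizer = optim.SGD(newmodel.parameters(), lr=0.001, momentum=0.9,
                      weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
newmodel.train()
for epoch in range(10):  # assumed epoch count; adjust as needed
    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = criterion(newmodel(data), target)
        loss.backward()
        optimizer.step()
test(newmodel)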

