Pruning Filters For Efficient ConvNets 剪枝代碼小結


The Code of Pruning Filters For Efficient ConvNets

1. 代碼參考

 https://github.com/tyui592/Pruning_filters_for_efficient_convnets

 其中主要是用VGG在CIFAR100上進行剪枝。以下是對代碼的理解:args是一些參數配置,比如VGG的權重路徑等信息。

def prune_network(args, network=None):
    """Prune a VGG network and optionally retrain the pruned result.

    Args:
        args: Run configuration object. This function reads ``gpu_no``,
            ``vgg``, ``data_set``, ``load_path``, ``prune_layers``,
            ``prune_channels``, ``independent_prune_flag``, ``retrain_flag``,
            ``retrain_epoch`` and ``retrain_lr``. When retraining it
            overwrites ``epoch``, ``lr`` and ``lr_milestone`` in place.
        network: Network to prune. When ``None`` a new ``VGG`` is built and,
            if ``args.load_path`` is set, initialised from that checkpoint.

    Returns:
        The pruned (and possibly retrained) network.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        network = VGG(args.vgg, args.data_set)
        # NOTE(review): the original formatting was lost; the checkpoint load
        # is assumed to apply only to a freshly built network - confirm.
        if args.load_path:
            # Pruning runs on the CPU, so map the checkpoint there; this also
            # lets a checkpoint saved on a GPU load on a CPU-only machine.
            check_point = torch.load(args.load_path, map_location="cpu")
            network.load_state_dict(check_point['state_dict'])

    # prune network
    network = prune_step(network, args.prune_layers, args.prune_channels,
                         args.independent_prune_flag)
    network = network.to(device)
    print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)

    if args.retrain_flag:
        # update arguments for retraining the pruned network
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay learning rate
        network = train_network(args, network)

    return network


def prune_step(network, prune_layers, prune_channels, independent_prune_flag):
    """Remove filters from the selected conv layers of ``network.features``.

    Walks ``network.features`` in order. For every conv layer named
    ``'conv<N>'`` (1-based count of ``Conv2d`` modules) that appears in
    ``prune_layers``, the filters with the smallest absolute weight sum are
    dropped. The following ``BatchNorm2d`` and the input channels of the next
    ``Conv2d`` are shrunk to match.

    Args:
        network: Model exposing indexable ``features`` and ``classifier``.
        prune_layers: Collection of layer names such as ``'conv1'``.
        prune_channels: Number of filters to remove per pruned layer, consumed
            in the order the pruned layers are encountered.
        independent_prune_flag: Forwarded to ``get_new_conv``. When true, the
            weights removed from a conv's input side are kept as ``residue``
            and added back when ranking that conv's filters.

    Returns:
        The same network object (moved to the CPU) with its layers replaced.
    """
    network = network.cpu()  # pruning is carried out on the CPU
    count = 0  # index into 'prune_channels'
    conv_count = 1  # conv counter used to build the 'conv%d' layer names
    # Filter weights are [out_ch, in_ch, k1, k2]. dim == 0: the next step
    # removes output filters of the current conv. dim == 1: the previous conv
    # lost filters, so the matching input channels of the current conv (and
    # the channels of any batch norm in between) still have to be removed.
    dim = 0
    residue = None  # removed input-side weights, used by the independent strategy
    channel_index = None  # indices removed from the most recently pruned conv

    for i in range(len(network.features)):
        layer = network.features[i]
        if isinstance(layer, torch.nn.Conv2d):
            if dim == 1:
                # The previous conv was pruned: drop the same channels from
                # this conv's input side.
                layer, residue = get_new_conv(layer, dim, channel_index,
                                              independent_prune_flag)
                network.features[i] = layer
                dim ^= 1

            if 'conv%d' % conv_count in prune_layers:
                # Prune this conv's output filters and remember which ones.
                channel_index = get_channel_index(layer.weight.data,
                                                  prune_channels[count],
                                                  residue)
                layer = get_new_conv(layer, dim, channel_index,
                                     independent_prune_flag)
                network.features[i] = layer
                dim ^= 1
                count += 1
            else:
                residue = None
            conv_count += 1

        elif dim == 1 and isinstance(layer, torch.nn.BatchNorm2d):
            # Batch norm has per-channel parameters too; keep them in sync.
            network.features[i] = get_new_norm(layer, channel_index)

    # dim == 1 here means the last conv lost filters and no later conv
    # consumed the change, so the first classifier layer must shrink instead.
    # The original tested for the literal name 'conv13', which gives the same
    # answer only when conv13 is the last conv layer.
    if dim == 1:
        network.classifier[0] = get_new_linear(network.classifier[0],
                                               channel_index)

    return network


def get_channel_index(kernel, num_elimination, residue=None):
    """Return the indices of the ``num_elimination`` weakest filters.

    Filters are ranked by the sum of absolute values over everything except
    the first dimension.

    Args:
        kernel: Weight tensor whose first dimension indexes the filters.
        num_elimination: How many indices to return.
        residue: Optional tensor with the same first dimension whose absolute
            sums are added to the ranking score.

    Returns:
        A list of ints, ordered from the smallest score upwards.
    """
    # reshape (not view) so non-contiguous tensors are accepted as well
    scores = torch.sum(torch.abs(kernel.reshape(kernel.size(0), -1)), dim=1)
    if residue is not None:
        scores = scores + torch.sum(
            torch.abs(residue.reshape(residue.size(0), -1)), dim=1)

    _, order = torch.sort(scores)
    return order[:num_elimination].tolist()


def index_remove(tensor, dim, index, removed=False):
    """Drop the slices listed in ``index`` along dimension ``dim``.

    Args:
        tensor: Source tensor; a CUDA tensor is copied to the CPU first.
        dim: Dimension to remove slices from.
        index: List of positions along ``dim`` to remove.
        removed: When true, also return the removed slices.

    Returns:
        The tensor without the listed slices, in their original order. With
        ``removed=True`` a tuple ``(kept, removed_slices)`` where
        ``removed_slices`` follows the order of ``index``.
    """
    if tensor.is_cuda:
        tensor = tensor.cpu()

    dropped = set(index)
    # Build the kept indices in ascending order explicitly instead of relying
    # on set iteration order, so the surviving channels keep their order.
    keep = [i for i in range(tensor.size(dim)) if i not in dropped]

    # An explicit integer dtype keeps index_select working for empty lists
    # (torch.tensor([]) would otherwise produce a float tensor).
    new_tensor = torch.index_select(
        tensor, dim, torch.tensor(keep, dtype=torch.long))

    if removed:
        removed_part = torch.index_select(
            tensor, dim, torch.tensor(list(index), dtype=torch.long))
        return new_tensor, removed_part

    return new_tensor
def get_new_conv(conv, dim, channel_index, independent_prune_flag=False):
    """Build a smaller ``Conv2d`` with the given channels removed.

    Args:
        conv: The ``torch.nn.Conv2d`` to shrink.
        dim: ``0`` removes output filters (weight and bias entries); ``1``
            removes input channels (weight only, bias is carried over).
        channel_index: List of channel positions to remove along ``dim``.
        independent_prune_flag: Only used when ``dim == 1``. When true, the
            removed input-side weights are returned as ``residue``.

    Returns:
        For ``dim == 0`` the new conv layer. For ``dim == 1`` a tuple
        ``(new_conv, residue)`` where ``residue`` is ``None`` unless
        ``independent_prune_flag`` is set.

    Raises:
        ValueError: If ``dim`` is neither 0 nor 1 (the original silently
            returned ``None`` in that case).
    """
    # Mirror the source layer's bias setting so bias-free convs neither gain
    # a random bias nor crash on ``conv.bias.data``.
    has_bias = conv.bias is not None

    if dim == 0:
        new_conv = torch.nn.Conv2d(
            in_channels=conv.in_channels,
            out_channels=int(conv.out_channels - len(channel_index)),
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=has_bias,
        )
        new_conv.weight.data = index_remove(conv.weight.data, dim,
                                            channel_index)
        if has_bias:
            new_conv.bias.data = index_remove(conv.bias.data, dim,
                                              channel_index)
        return new_conv

    if dim == 1:
        new_conv = torch.nn.Conv2d(
            in_channels=int(conv.in_channels - len(channel_index)),
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=has_bias,
        )
        # With the flag set, index_remove returns (kept, removed).
        new_weight = index_remove(conv.weight.data, dim, channel_index,
                                  independent_prune_flag)
        residue = None
        if independent_prune_flag:
            new_weight, residue = new_weight

        new_conv.weight.data = new_weight
        if has_bias:
            # Output channels are unchanged, so the bias carries over as is.
            new_conv.bias.data = conv.bias.data
        return new_conv, residue

    raise ValueError("dim must be 0 or 1, got %r" % (dim,))


def get_new_norm(norm, channel_index):
    """Build a smaller ``BatchNorm2d`` with the given channels removed.

    Args:
        norm: The ``torch.nn.BatchNorm2d`` to shrink.
        channel_index: List of channel positions to remove.

    Returns:
        A new ``BatchNorm2d`` with the same ``eps``, ``momentum``, ``affine``
        and ``track_running_stats`` settings, holding the surviving channels'
        parameters and running statistics.
    """
    new_norm = torch.nn.BatchNorm2d(
        num_features=int(norm.num_features - len(channel_index)),
        eps=norm.eps,
        momentum=norm.momentum,
        affine=norm.affine,
        track_running_stats=norm.track_running_stats,
    )

    # weight/bias only exist when affine=True; the original dereferenced them
    # unconditionally.
    if norm.affine:
        new_norm.weight.data = index_remove(norm.weight.data, 0,
                                            channel_index)
        new_norm.bias.data = index_remove(norm.bias.data, 0, channel_index)

    if norm.track_running_stats:
        new_norm.running_mean.data = index_remove(norm.running_mean.data, 0,
                                                  channel_index)
        new_norm.running_var.data = index_remove(norm.running_var.data, 0,
                                                 channel_index)

    return new_norm

# The fully connected layer must shrink as well, because the number of
# filters feeding it has changed.
def get_new_linear(linear, channel_index):
    """Build a smaller ``Linear`` with the given input features removed.

    Args:
        linear: The ``torch.nn.Linear`` to shrink.
        channel_index: List of input-feature positions to remove.

    Returns:
        A new ``Linear`` with ``len(channel_index)`` fewer input features, the
        matching weight columns removed and the bias (if any) carried over.
    """
    has_bias = linear.bias is not None
    new_linear = torch.nn.Linear(
        in_features=int(linear.in_features - len(channel_index)),
        out_features=linear.out_features,
        bias=has_bias,
    )
    new_linear.weight.data = index_remove(linear.weight.data, 1,
                                          channel_index)
    # The constructor already honours a missing bias; only copy it when it
    # exists (the original read linear.bias.data unconditionally).
    if has_bias:
        new_linear.bias.data = linear.bias.data

    return new_linear

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM