Pytorch 剪枝操作實現
首先需要版本為 1.4 以上,
目前很多模型都取得了十分好的結果, 但是還是參數太多, 占得權重太大, 所以我們的目標是得到一個稀疏的子系數矩陣.
這個例子是基於 LeNet 的 Pytorch 實現的例子, 我們從 CNN 的角度來剪枝, 其實在全連接層與 RNN 的剪枝應該是類似, 首先導入一些必要的模塊
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
然后是 LeNet 的網絡結構, 不知道為什么這里的網絡結構是這樣的, 算出來輸入的圖像是 26x26 的,
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
# 第一個卷積層, 輸出的向量維度是 6
self.conv2 = nn.Conv2d(6, 16, 3)
# 第二個卷積層, 輸出的向量維度是 16
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
# 最后將二維向量變成一維
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# 2*2 的池化層
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# relu 激活函數層
x = x.view(-1, int(x.nelement() / x.shape[0]))
# 除以 batch_size 的大小, 將維度變成一
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
這時查看模型的參數:
module = model.conv1
print(list(module.named_parameters()))
此時參數包含矩陣的權值與偏置.
為了剪枝一個模型, 首先要在 torch.nn.utils.prune
中選擇一種剪枝方法, 或者使用子類 BasePruningMethod
實現自己的剪枝方法, 然后確定模型以及需要減去的參數, 最后,使用所選修剪技術所需的適當關鍵字參數,指定修剪參數. 在下面的例子中, 我們將要隨機減去 conv1 層中的 30% 的權重參數, module 是函數的第一個參數, name 使用的是參數的字符串標識, amount 表示剪枝的百分比.
prune.random_unstructured(module, name="weight", amount=0.3)
剪枝行為將 weight 參數名稱刪除, 並將其替代為新的參數名稱, weight_orig
, weight_orig存儲未修剪的張量版本. 也就是說 weight_orig 是原來的權重,
上述的剪枝方法會產生一個 mask 矩陣, 叫做 weight_mask , 存儲為一個 module buffer , 相當於一個 mask矩陣, 他的維度與 weight 的維度相同, 不同的是 mask 矩陣是一個 0/1 矩陣. 可以通過下面的函數查看 mask 矩陣:
print(list(module.named_buffers()))
剪枝之后的權重屬性 weight 不再是權重的集合, 而是 mask 矩陣與原始矩陣的結合, 所以不再是模型的一個 parameter, 而是一個 attribute.
最后,使用 PyTorch 的forward_pre_hooks在每次正向傳遞之前應用修剪。具體來說,如我們在此處所做的那樣,在剪枝模塊部分,它將為與之相關的每個要修剪的參數獲取一個forward_pre_hook。目前為止我們只修剪了名為weight的原始參數,因此將只存在一個 forward_pre_hook, 相當於沒有一個剪枝參數就有一個 forward_pre_hook.
除了對 weight 剪枝, 還可以對 bias 剪枝, 下面是通過 L1 范式剪去三個單元
prune.l1_unstructured(module, name="bias", amount=3)
# Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units with the lowest L1-norm.
Iterative Pruning
相同的參數在一個模型中可以被多次剪枝, 相當於把多個剪枝核序列化成一個剪枝核, 新的 mask 矩陣與舊的 mask 矩陣的結合使用 PruningContainer
中的 compute_mask
方法. 比如在上面的 module 的 weight 中, 我們除了隨機剪枝外還可以通過范式剪枝, 下面是個例子:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
# 這里的 n 表示剪枝的范式, dim = 0, 表示參數矩陣的維度, 這里卷積層的 dim= 0, 就是核的個數
print(module.weight)
剪完之后, 核的個數變成原來的一半. mask 矩陣也會自動疊加.
還可以通過下面的方法查看我們使用了哪些方法剪枝, hook 記錄了某個 attribute 的剪枝方法:
for hook in module._forward_pre_hooks.values():
if hook._tensor_name == "weight": # select out the correct hook
break
print(list(hook)) # pruning history in the container
Serializing a pruned model
所有相關的張量,包括掩碼緩沖區和用於計算修剪的張量的原始參數,都存儲在模型的 state_dict 中,因此可以根據需要輕松地序列化和保存.
我們可以通過下面的方法查看模型中的權重參數:
>> print(model.state_dict().keys())
>> odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
Remove pruning re-parametrization
注意, 這里的刪除剪枝的意思並不是真正的刪除, 還原到未剪枝的狀態. 舉個例子, 剪枝之后, 我們的參數 parameters 中的 weight 會變成, 'weight_orig', 而 weight 變成一個屬性, 他是 'weight_orig' 與 mask 矩陣結合后的結果, 那么
prune.remove(module, 'weight')
之后會發生什么呢?
print(list(module.named_parameters()))
('weight', Parameter containing:
tensor([[[[-0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000]]],
.......
也就是說, weight 又變成了 parameters, 剪枝變成永久化.
Pruning multiple parameters in a model
多個參數, 多個網絡結構的剪枝,
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# 將所有卷積層的權重減去 20%
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
# 將所有全連接層的權重減去 40%
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
Global pruning
之前的剪枝我們都是針對每一層每一層的剪枝, 減去某一層權重的百分比, 對於全局剪枝就是將模型的參數看成一個整體, 減去一部分參數, 對於每一層減去的比例可能不同.
剪枝的方法可以通過下面的方法:
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
使用自定義的方法剪枝
要實現自己的修剪功能,可以通過將 BasePruningMethod 基類作為子類來擴展 nn.utils.prune
模塊,就像其他所有修剪方法一樣. 基類以及完成了下面的方法:
__call__
, apply_mask
, apply
, prune
, and remove
除了一些特殊的情況, 你不需要重寫這些方法以實現新的剪枝方法. 你需要實現的是:
__init__
構造器compute_mask
如何根據剪枝策略的邏輯為給定張量計算 mask- 需要說明是全局剪枝, 還是結構剪枝, 或者是非結構剪枝, 這決定了在迭代剪枝是如何結合 mask 矩陣, 換句話說,當剪枝需要剪枝的參數時,當前的剪枝策略應作用於參數的未剪枝部分。指定
PRUNING_TYPE
將啟用 PruningContainer 正確識別要修剪的參數的范圍.
比如說, 當我們希望剪枝一個張量中除了某一參數外的所有其他參數的時候, 或者說這個張量已經被部分剪枝的時候, 我們就需要設置: PRUNING_TYPE='unstructured'
因為他只是單獨作用與一層, 而不是一個單元或者通道(對應於'structured'), 也不是作用於整個參數(對應於'global')
class FooBarPruningMethod(prune.BasePruningMethod):
# 繼承自基類 BasePruningMethod
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
# 類型為 unstructured 類型
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
# 定義了 mask 矩陣的構成方法, 每兩個數字一個 0
return mask
然后給出一個調用的例子:
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
以上就是Pytorch 剪枝的主要方法, 其實對於復雜的剪枝方法, 只要在 compute_mask
設置特殊的 mask 構成方法就可以了.