Pytorch 剪枝操作實現


Pytorch 剪枝操作實現

首先需要版本為 1.4 以上,

目前很多模型都取得了十分好的結果, 但是還是參數太多, 占得權重太大, 所以我們的目標是得到一個稀疏的子系數矩陣.

這個例子是基於 LeNet 的 Pytorch 實現的例子, 我們從 CNN 的角度來剪枝, 其實在全連接層與 RNN 的剪枝應該是類似, 首先導入一些必要的模塊

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

然后是 LeNet 的網絡結構, 不知道為什么這里的網絡結構是這樣的, 算出來輸入的圖像是 26x26 的,

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        # 第一個卷積層, 輸出的向量維度是 6
        self.conv2 = nn.Conv2d(6, 16, 3)
        # 第二個卷積層, 輸出的向量維度是 16
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        # 最后將二維向量變成一維
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 2*2 的池化層
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # relu 激活函數層
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        # 除以 batch_size 的大小, 將維度變成一
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

這時查看模型的參數:

module = model.conv1
print(list(module.named_parameters()))

此時參數包含矩陣的權值與偏置.

為了剪枝一個模型, 首先要在 torch.nn.utils.prune 中選擇一種剪枝方法, 或者使用子類 BasePruningMethod 實現自己的剪枝方法, 然后確定模型以及需要減去的參數, 最后,使用所選修剪技術所需的適當關鍵字參數,指定修剪參數. 在下面的例子中, 我們將要隨機減去 conv1 層中的 30% 的權重參數, module 是函數的第一個參數, name 使用的是參數的字符串標識, amount 表示剪枝的百分比.

prune.random_unstructured(module, name="weight", amount=0.3)

剪枝行為將 weight 參數名稱刪除, 並將其替代為新的參數名稱, weight_orig , weight_orig存儲未修剪的張量版本. 也就是說 weight_orig 是原來的權重,

上述的剪枝方法會產生一個 mask 矩陣, 叫做 weight_mask , 存儲為一個 module buffer , 相當於一個 mask矩陣, 他的維度與 weight 的維度相同, 不同的是 mask 矩陣是一個 0/1 矩陣. 可以通過下面的函數查看 mask 矩陣:

print(list(module.named_buffers()))

剪枝之后的權重屬性 weight 不再是權重的集合, 而是 mask 矩陣與原始矩陣的結合, 所以不再是模型的一個 parameter, 而是一個 attribute.

最后,使用 PyTorch 的forward_pre_hooks在每次正向傳遞之前應用修剪。具體來說,如我們在此處所做的那樣,在剪枝模塊部分,它將為與之相關的每個要修剪的參數獲取一個forward_pre_hook。目前為止我們只修剪了名為weight的原始參數,因此將只存在一個 forward_pre_hook, 相當於沒有一個剪枝參數就有一個 forward_pre_hook.

除了對 weight 剪枝, 還可以對 bias 剪枝, 下面是通過 L1 范式剪去三個單元

prune.l1_unstructured(module, name="bias", amount=3)
# Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units with the lowest L1-norm.

Iterative Pruning

相同的參數在一個模型中可以被多次剪枝, 相當於把多個剪枝核序列化成一個剪枝核, 新的 mask 矩陣與舊的 mask 矩陣的結合使用 PruningContainer 中的 compute_mask 方法. 比如在上面的 module 的 weight 中, 我們除了隨機剪枝外還可以通過范式剪枝, 下面是個例子:

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
# 這里的 n 表示剪枝的范式, dim = 0, 表示參數矩陣的維度, 這里卷積層的 dim= 0, 就是核的個數
print(module.weight)

剪完之后, 核的個數變成原來的一半. mask 矩陣也會自動疊加.

還可以通過下面的方法查看我們使用了哪些方法剪枝, hook 記錄了某個 attribute 的剪枝方法:

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

Serializing a pruned model

所有相關的張量,包括掩碼緩沖區和用於計算修剪的張量的原始參數,都存儲在模型的 state_dict 中,因此可以根據需要輕松地序列化和保存.

我們可以通過下面的方法查看模型中的權重參數:

>> print(model.state_dict().keys())
>> odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

Remove pruning re-parametrization

注意, 這里的刪除剪枝的意思並不是真正的刪除, 還原到未剪枝的狀態. 舉個例子, 剪枝之后, 我們的參數 parameters 中的 weight 會變成, 'weight_orig', 而 weight 變成一個屬性, 他是 'weight_orig' 與 mask 矩陣結合后的結果, 那么

prune.remove(module, 'weight')

之后會發生什么呢?

print(list(module.named_parameters()))
('weight', Parameter containing:
tensor([[[[-0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],
.......

也就是說, weight 又變成了 parameters, 剪枝變成永久化.

Pruning multiple parameters in a model

多個參數, 多個網絡結構的剪枝,

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
        # 將所有卷積層的權重減去 20%
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        # 將所有全連接層的權重減去 40%

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

Global pruning

之前的剪枝我們都是針對每一層每一層的剪枝, 減去某一層權重的百分比, 對於全局剪枝就是將模型的參數看成一個整體, 減去一部分參數, 對於每一層減去的比例可能不同.

剪枝的方法可以通過下面的方法:

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

使用自定義的方法剪枝

要實現自己的修剪功能,可以通過將 BasePruningMethod 基類作為子類來擴展 nn.utils.prune 模塊,就像其他所有修剪方法一樣. 基類以及完成了下面的方法:

__call__, apply_mask, apply, prune, and remove

除了一些特殊的情況, 你不需要重寫這些方法以實現新的剪枝方法. 你需要實現的是:

  1. __init__ 構造器
  2. compute_mask 如何根據剪枝策略的邏輯為給定張量計算 mask
  3. 需要說明是全局剪枝, 還是結構剪枝, 或者是非結構剪枝, 這決定了在迭代剪枝是如何結合 mask 矩陣, 換句話說,當剪枝需要剪枝的參數時,當前的剪枝策略應作用於參數的未剪枝部分。指定 PRUNING_TYPE 將啟用 PruningContainer 正確識別要修剪的參數的范圍.

比如說, 當我們希望剪枝一個張量中除了某一參數外的所有其他參數的時候, 或者說這個張量已經被部分剪枝的時候, 我們就需要設置: PRUNING_TYPE='unstructured' 因為他只是單獨作用與一層, 而不是一個單元或者通道(對應於'structured'), 也不是作用於整個參數(對應於'global')

class FooBarPruningMethod(prune.BasePruningMethod):
    # 繼承自基類 BasePruningMethod
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'
    # 類型為 unstructured 類型

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        # 定義了 mask 矩陣的構成方法, 每兩個數字一個 0
        return mask

然后給出一個調用的例子:

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

以上就是Pytorch 剪枝的主要方法, 其實對於復雜的剪枝方法, 只要在 compute_mask 設置特殊的 mask 構成方法就可以了.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM