Pytorch推出fx,量化起飛


(本文首發於公眾號,沒事來逛逛)

Pytorch1.8 發布后,官方推出一個 torch.fx 的工具包,可以動態地對 forward 流程進行跟蹤,並構建出模型的圖結構。這個新特性能帶來什么功能呢?別的不說,就模型量化這一塊,煉丹師們有福了。

其實早在三年前 pytorch1.3 發布的時候,官方就推出了量化功能。但我覺得當時官方重點是在后端的量化推理引擎(FBGEMM 和 QNNPACK)上,對於 pytorch 前端的接口設計很粗糙。用過 pytorch 量化的同學都知道,這個量化接口實在是太麻煩、太粗糙、太暴力了。官方又把這個第一代的量化方式稱為 Eager Mode Quantization。我后面會用一個例子來展示這種方式有多傻x。

而隨着 fx 的推出,由於可以動態地 trace 出網絡的圖結構,因此就可以針對網絡模型動態地添加一些量化節點。官方又稱這種新的量化方式為 FX Graph Mode Quantization。上一張官方的圖來對比一下這兩種方式的優缺點:

我總結一下這張圖,Eager Mode Quantization 需要手工修改網絡代碼,並對很多節點進行替換,而 FX Graph Mode Quantization 則大大提高了自動化的能力。

現在就用代碼實際對比一下二者的差異。

首先,定義一個簡單的網絡:

class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.fc = nn.Linear(5*5*40, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.reshape(-1, 5*5*40)
        x = self.fc(x)
        return x

這個網絡的寫法應該是很常見的,結構非常簡單。pytorch 這種動態圖有一個很好的地方,就是可以在 forward 函數中天馬星空構造「電路圖」,比如 Functional 這些函數模塊可以隨意調用,而不需要在 init 函數里面事先定義,再比如可以隨時加入 if、for 等邏輯控制語句。這就是動態圖區別於靜態圖的地方。但這種好處的代價就是,我們很難獲取網絡的圖結構。

下面我們就看看 Eager 模式下的量化怎么操作。

看過我之前量化系列教程的讀者應該知道,模型量化需要在原網絡節點中插入一些偽量化節點,或者把一些 Module 或者 Function 替換成量化的形式。對於 Eager 模式,由於它只會對 init 函數里面定義的模塊進行替換,因此,如果有一些 op 沒有在 init 中定義,但又在 forward 中用到了(比如上面代碼的 F.relu),那就涼涼了。

因此,上面這段網絡代碼是沒法直接用 Eager 模式量化的,需要重新寫成下面這種形式:

class NetQuant(nn.Module):

    def __init__(self, num_channels=1):
        super(NetQuant, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(5*5*40, 10)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        x = x.reshape(-1, 5*5*40)
        x = self.fc(x)
        x = self.dequant(x)
        return x

這樣一來,除了 ConvLinear 這些含有參數的 Module 外,ReLUMaxPool2d 也在 init 中定義了,Eager 模式才能進行處理。

這還沒完,由於有些節點是要做 fuse 之后才能量化的(比如:Conv + ReLU),因此我們需要手動指定這些層進行合並:

model = NetQuant()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
modules_to_fuse = [['conv1', 'relu1'], ['conv2', 'relu2']]  # 指定合並layer的名字
model_fused = torch.quantization.fuse_modules(model, modules_to_fuse)
model_prepared = torch.quantization.prepare(model_fused)
post_training_quantize(model_prepared, train_loader)   # 這一步是做后訓練量化
model_int8 = torch.quantization.convert(model_prepared)

這一套流程下來不可謂不繁瑣,而且,這只是一個相當簡單的網絡,遇上復雜的,或者是別人天馬行空寫完丟給你量化的網絡,分分鍾可以去世。pytorch 這套設計直接勸退了很多想上手量化的同學,我很早之前看到這些操作也是一點上手的欲望都沒有。

那這套新的 Graph 模式的量化又是怎樣的呢?

由於 FX 可以自動跟蹤 forward 里面的代碼,因此它是真正記錄了網絡里面的每個節點,在 fuse 和動態插入量化節點方面,要比 Eager 模式強太多。還是前面那個模型代碼,我們不需要對網絡做修改,直接讓 FX 幫我們自動修改網絡即可:

from torch.quantization import get_default_qconfig, quantize_jit
from torch.quantization.quantize_fx import prepare_fx, convert_fx
model = Net()  
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
model_prepared = prepare_fx(model, qconfig_dict)
post_training_quantize(model_prepared, train_loader)      # 這一步是做后訓練量化
model_int8 = convert_fx(model_prepared)

對比一下前面 Eager 模式的流程,有沒有感覺自己又可以了。

目前 FX 這個新工具包還在優化中,很多功能並不完善。比如,如果 forward 代碼中出現了 if 和 for 等控制語句,它依然還是解析不了,這個時候就需要你把 if 還有 for 語句手動拆解掉。但相比起之前的流程,已經是一個巨大的進步了。而且,有了這個圖結構,很多后訓練量化的算法也可以更加方便的操作(很多 PTQ 的算法需要針對針對網絡的拓撲結構優化)。除此以外,像 NAS 等模型結構搜索之類的算法,也可以更加方便的進行。

總的來說,pytorch 推出的這個新特性實在是極大彌補了動態圖的先天不足。之前一直考慮針對 pytorch 做一些離線量化的工具,但由於它的圖結構很難獲取,因此一直難以入手(ONNX 和 jit 這些工具對量化支持又不夠)。現在有了 fx,感覺可以加油起飛了。希望官方再接再厲,不要機毀人亡。

歡迎關注我的公眾號:大白話AI,立志用大白話講懂AI。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM