pytorch jit的學習
TorchScript:
TorchScript是一個靜態類型的Python子集,可以直接編寫(使用@torch.jit。
腳本裝飾器)或通過跟蹤從Python代碼自動生成。
在使用跟蹤時,通過只記錄張量上的實際操作符,並簡單地執行和丟棄周圍的其他Python代碼,代碼會自動轉換為Python的這個子集。
當使用@torch.jit直接編寫TorchScript時。
腳本裝飾器,程序員必須只使用TorchScript中支持的Python子集。
本節記錄了TorchScript支持的內容,就好像它是一種獨立語言的語言參考。
本參考中未提及的Python特性都不是TorchScript的一部分。
有關可用Pytorch張量方法、模塊和函數的完整參考,請參閱內置函數。
作為Python的子集,任何有效的TorchScript函數也是一個有效的Python函數。
這使得禁用TorchScript和使用標准Python工具(如pdb)調試該函數成為可能。
反之則不然:有許多有效的Python程序不是有效的TorchScript程序。
相反,TorchScript專門關注Python的一些特性,這些特性需要在PyTorch中表示神經網絡模型.
以上節選自pytorch官網介紹
簡而言之:pytorch script 以一種特定語言描述從python導出模型,並可在任意非python環境中導入使用
簡單案例
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
利用裝飾器:
import torch
# 跟蹤函數
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
常用操作可見官網:
TorchScript
script, trace差異
import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3),torch.rand(3)))
trace僅記錄張量上的操作,因此它不會記錄任何控制流操作,如if語句或循環。
當你的模型涉及復雜控制流操作,得用script
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
torch.script 保存讀取
Script中的核心數據結構是ScriptModule。 它是Torch的nn.Module的類似物,代表整個模型作為子模塊樹。 與普通模塊一樣,ScriptModule中的每個單獨模塊都可以包含子模塊,參數和方法
對於sciptmodule的保存跟普通moudle類似
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
# 利用torch.jit.save保存模型
torch.jit.save(traced_cpu, "cpu.pt")
traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")
# ... later, when using the model:
if use_gpu:
# 對應利用jit.load讀取模型
model = torch.jit.load("gpu.pt")
else:
model = torch.jit.load("cpu.pt")
model(input)
Script Module的解釋圖表
TorchScript使用靜態單一賦值(SSA)中間表示(IR)來表示計算。 這種格式的指令包括ATen(PyTorch的C ++后端)運算符和其他原始運算符,包括循環和條件的控制流運算符。 舉個例子:

code屬性:

graph屬性:

打印子層
idx = 0
for name, cr in m.named_children():
print(f"{idx} layer: {name}")
print(cr)
idx+=1
額外小知識
利用torch.jit.save保存的.pth文件可以通過壓縮軟件打開,可以直接看到里面的code

pytorch jit中的一些優化
torch._C._jit_set_profiling_mode()
torch.jit.optimized_execution()
參考:
