pytorch JIT淺解析


概要
  Torch Script中的核心數據結構是ScriptModule。 它是Torch的nn.Module的類似物,代表整個模型作為子模塊樹。 與普通模塊一樣,ScriptModule中的每個單獨模塊都可以包含子模塊,參數和方法。 在nn.Modules中,方法是作為Python函數實現的,但在ScriptModules方法中通常實現為Torch Script函數,這是一個靜態類型的Python子集,包含PyTorch的所有內置Tensor操作。 這種差異允許您運行ScriptModules代碼而無需Python解釋器。

ScriptModules和Torch Script函數可以通過兩種方式創建:
Tracing:
  使用torch.jit.trace,您可以獲取現有模塊或python函數,提供示例輸入,然后運行該函數,記錄在所有張量上執行的操作。 我們將生成的記錄轉換為Torch Script方法,該方法作為ScriptModule的正向方法安裝。 該模塊還包含原始模塊所具有的任何參數。
Example:

import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
1
2
3
4
注意:

  由於跟蹤僅記錄張量上的操作,因此它不會記錄任何控制流操作,如if語句或循環。 當這個控制流在你的模塊中保持不變時,這很好,它通常只是內聯配置決策。 但有時控制流實際上是模型本身的一部分。 例如,序列到序列轉換中的波束搜索是輸入的(變化的)序列長度上的循環。 在這種情況下,跟蹤不合適,並且應使用腳本編寫波束搜索。


Scripting:
  您可以使用Python語法直接編寫Torch Script代碼。 您可以在ScriptModule的子類上使用torch.jit.script批注(對於函數)或torch.jit.script_method批注(對於方法)來執行此操作。 使用此注釋,注釋函數的主體將直接轉換為Torch腳本。 Torch腳本本身是Python語言的一個子集,因此並非python中的所有功能都可以工作,但我們提供了足夠的功能來計算張量並執行與控制相關的操作。
實例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

class MyScriptModule(ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# trace produces a ScriptModule's conv1 and conv2
self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

@script_method
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
  用於將JIT模式PyTorch程序轉換為Torch腳本的API可在torch.jit模塊中找到。該模塊有兩種核心模式,用於將JIT模式模型轉換為Torch Script圖形表示:Tracing:和Scripting:。
  torch.jit.trace函數接受一個模塊或函數以及一組示例輸入。然后,它在跟蹤遇到的計算步驟時通過函數或模塊運行示例輸入,並輸出執行Tracing操作的基於圖形的函數。Tracing非常適用於不涉及數據相關控制流的簡單模塊和功能,例如標准卷積神經網絡。但是,如果Tracing具有依賴於數據的if語句和循環的函數,則僅記錄由示例輸入執行的執行路徑調用的操作。換句話說,不捕獲控制流本身。 為了轉換包含依賴於數據的控制流的模塊和函數,提供了一種 Script機制。
  Script顯式將模塊或功能代碼轉換為Torch Script,包括所有可能的控制流路徑。 要使用腳本模式,請確保從torch.jit.ScriptModule基類(而不是torch.nn.Module)繼承,並將torch.jit.script裝飾器添加到Python函數或torch.jit.script_method裝飾器中。你的模塊的方法。使用腳本的一個警告是它只支持Python的受限子集。下面會描述當前pytorch JIT支持的功能的所有詳細信息。為了提供最大的靈活性,可以組合Torch腳本的模式來表示整個程序,並且可以逐步應用這些技術。

TORCH SCRIPT LANGUAGE REFERENCE
Torch Script是Python的一個子集,可以直接編寫(使用@script注釋),也可以通過跟蹤從Python代碼自動生成。 使用跟蹤時,代碼會自動轉換為Python的這個子集,方法是僅記錄張量上的實際運算符,並簡單地執行和丟棄其他周圍的Python代碼。

使用@script注釋直接編寫Torch腳本時,程序員必須只使用Torch腳本支持的Python子集。 本節介紹了Torch Script支持的內容,就好像它是獨立語言的語言參考一樣。 本參考中未提及的Python的任何功能都不是Torch腳本的一部分。

作為Python的一個子集,任何有效的Torch Script函數也是一個有效的Python函數。 這樣就可以刪除@script注釋並使用標准Python工具(如pdb)調試函數。 反之亦然:有許多有效的python程序不是有效的Torch Script程序。 相反,Torch Script專注於在Torch中表示神經網絡模型所需的Python特性。

PYTORCH_JIT= 1
設置環境變量PYTORCH_JIT = 0將禁用所有腳本和跟蹤注釋。 如果其中一個ScriptModule中存在難以調試的錯誤,則可以使用此標志強制所有內容都使用本機Python運行。 這允許使用像pdb這樣的工具來調試代碼。

1.Types,支持的類型
Torch Script與完整Python語言之間的最大區別在於Torch Script僅支持表達神經網絡模型所需的一小部分類型。 特別是Torch Script支持:

Tensor
  任何dtype,dimension或backend的PyTorch Tensor。

Tuple[T0, T1, …]
  包含子類型T0,T1等的元組(例如Tuple[Tensor,Tensor])

int
  int標量

float
  float 標量

List[T]
  所有成員都是T類的列表與Python不同,Torch Script函數中的每個變量都必須具有單個靜態類型。 這樣可以更輕松地優化Torch Script功能。
Example,下面這種情況應該避免,返回類型不一致:

@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r # Type mismatch: r is set to type Tensor in the true branch
# and type int in the false branch
1
2
3
4
5
6
7
8
默認情況下,假定Torch腳本函數的所有參數都是Tensor,因為這是模塊中最常用的類型。 要指定Torch腳本函數的參數是另一種類型,可以使用上面列出的類型使用MyPy樣式類型注釋:

@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))
1
2
3
4
5
6
7
Tips:也可以使用Python 3類型注釋來注釋類型。 在我們的示例中,我們使用基於注釋的注釋來確保Python 2的兼容性。
1
2.Expressions,表示
支持以下Python表達式

Literals,常量:
True, False, None, 'string literals', "string literals", number literals 3 (interpreted as int) 3.4 (interpreter as a float)
1
Variables,變量:
a

Variable Resolution,變量分辨能力
  Torch Script支持Python的可變分辨率(即范圍)規則的子集。 局部變量的行為與Python中的相同,除了變量必須在函數的所有路徑中具有相同類型的限制。 如果變量在if語句的不同側具有不同的類型,則在if語句結束后使用它是錯誤的。

類似地,如果僅沿着函數的某些路徑定義變量,則不允許使用該變量。

@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y) # Error: undefined value y
1
2
3
4
5
定義函數時,非局部變量在編譯時解析為Python值。 然后,使用“使用Python值”中描述的規則將這些值轉換為Torch Script值。

Tuple Construction
(3, 4), (3,)

List Construction
[3, 4], [], [torch.rand(3), torch.rand(4)]

假設空列表具有類型List [Tensor]。 其他列表文字的類型是從成員的類型派生的。

Arithmetic Operators
a + b a - b a * b a / b a ^ b a @ b

Comparison Operators
a == b a != b a < b a > b a <= b a >= b

Logical Operators
a and b a or b not b

Subscripts
t[0] t[-1] t[0:2] t[1:] t[:1] t[:] t[0, 1] t[0, 1:2] t[0, :1] t[-1, 1:, 0] t[1:, -1, 0] t[i:j, i]

Torch Script目前不支持變異張量,因此任何張量索引只能出現在表達式的右側size上。

Function calls
調用內置函數: torch.rand(3, dtype=torch.int)
調用其他script函數:

import torch

@torch.jit.script
def foo(x):
return x + 1

@torch.jit.script
def bar(x):
return foo(x)
1
2
3
4
5
6
7
8
9
Method calls
調用內置類型的方法,如Tensor:x.mm(y)

在ScriptModule中定義Script方法時,使用@script_method批注。 在這些方法中,可以調用此類的其他方法或訪問子模塊上的方法。

直接調用子模塊(例如self.resnet(輸入))等同於調用其正向方法(例如self.resnet.forward(input))

import torch

class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))

@torch.jit.script_method
def helper(self, input):
return self.resnet(input - self.means)

@torch.jit.script_method
def forward(self, input):
return self.helper(input)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
If expressions
x if x > y else y

Casts
float(ten), int(3.5), bool(ten)

Accessing Module Parameters
self.my_parameter self.my_submodule.my_parameter

3.Statements
Torch Script支持以下類型的語句:

Simple Assignments 簡單的賦值
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

Pattern Matching Assignments
a, b = tuple_or_list
a, b, *c = a_tuple

Print Statements
print(“the result of an add:”, a + b)

If Statements
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
1
2
3
4
5
6
While Loops
a = 0
while a < 4:
print(a)
a += 1
1
2
3
4
For loops with range
x = 0
for i in range(10):
x *= i
1
2
3
NOTE:Script當前不支持迭代通用可迭代對象,如list或tensor。 腳本當前不支持啟動或增加范圍的參數。 這些將在未來版本中添加。

For loops over tuples:
tup = (3, torch.rand(4))
for x in tup:
print(x)
1
2
3
Note:對於tuples的循環將展開循環,為tuples的每個成員生成一個主體。 正文必須為每個成員正確地進行類型檢查。

For loops over constant torch.nn.ModuleList
class SubModule(torch.jit.ScriptModule):
def __init__(self):
super(Sub, self).__init__()
self.weight = nn.Parameter(torch.randn(2))

@torch.jit.script_method
def forward(self, input):
return self.weight + input

class MyModule(torch.jit.ScriptModule):
__constants__ = ['mods']

def __init__(self):
super(MyModule, self).__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

@torch.jit.script_method
def forward(self, v):
for module in self.mods:
v = m(v)
return v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
要在@script_method中使用ModuleList,必須通過將屬性的名稱添加到該類型的__constants__列表來將其標記為常量。 對於ModuleList上的循環,將在編譯時使用常量模塊列表的每個成員展開循環體。

Return
return a, b

Note:必須有一個return語句作為函數的最后一個成員,並且return語句不能出現在函數的任何其他位置。 此限制將在以后刪除。

4.Debugging
Disable JIT for Debugging
如果要禁用所有JIT模式(跟蹤和腳本),以便可以在原始Python中調試程序,則可以使用PYTORCH_JIT環境變量。 PYTORCH_JIT可以通過將其值設置為0來全局禁用JIT。給出一個示例腳本:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x


def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))

traced_fn(torch.rand(3, 4))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
除了調用@script函數之外,使用PDB調試此腳本的工作原理除外。 我們可以全局禁用JIT,這樣我們就可以將@script函數作為普通的python函數調用而不是編譯它。 如果上面的腳本名為disable_jit_example.py,我們可以像這樣調用它:

$ PYTORCH_JIT=0 python disable_jit_example.py
1
我們將能夠作為普通的Python函數進入@script函數。

Interpreting Graphs,解釋圖表
TorchScript使用靜態單一賦值(SSA)中間表示(IR)來表示計算。 這種格式的指令包括ATen(PyTorch的C ++后端)運算符和其他原始運算符,包括循環和條件的控制流運算符。 舉個例子:

@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv

print(foo.graph)
1
2
3
4
5
6
7
8
9
10
11
12
具有單個forward方法的ScriptModule將具有屬性圖,您可以使用該圖來檢查表示計算的IR。 如果ScriptModule有多個方法,則需要訪問方法本身的.graph而不是模塊。 我們可以通過訪問.bar.graph來檢查ScriptModule上名為bar的方法的圖形。
上面的示例腳本生成圖形:

graph(%len : int) {
%13 : float = prim::Constant[value=1]()
%10 : int = prim::Constant[value=10]()
%2 : int = prim::Constant[value=4]()
%1 : int = prim::Constant[value=3]()
%3 : int[] = prim::ListConstruct(%1, %2)
%4 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=0]()
%6 : int[] = prim::Constant[value=[0, -1]]()
%rv.1 : Dynamic = aten::zeros(%3, %4, %5, %6)
%8 : int = prim::Constant[value=1]()
%rv : Dynamic = prim::Loop(%len, %8, %rv.1)
block0(%i : int, %12 : Dynamic) {
%11 : int = aten::lt(%i, %10)
%rv.4 : Dynamic = prim::If(%11)
block0() {
%14 : int = prim::Constant[value=1]()
%rv.2 : Dynamic = aten::sub(%12, %13, %14)
-> (%rv.2)
}
block1() {
%16 : int = prim::Constant[value=1]()
%rv.3 : Dynamic = aten::add(%12, %13, %16)
-> (%rv.3)
}
%19 : int = prim::Constant[value=1]()
-> (%19, %rv.4)
}
return (%rv);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
以指令%rv.1:Dynamic = aten :: zeros(%3,%4,%5,%6)為例。 %rv.1:動態意味着我們將輸出分配給名為rv.1的(唯一)值,並且該值是動態類型,即我們不知道其具體形狀。 aten :: zeros是運算符(相當於torch.zeros),輸入列表(%3,%4,%5,%6)指定范圍中的哪些值應作為輸入傳遞。 內置函數(如aten :: zeros)的模式可以在Builtin Functions中找到。

Builtin Functions:
  Torch Script支持PyTorch提供的內置張量和神經網絡函數的子集。 Tensor上的大多數方法以及torch命名空間中的函數都可用。 torch.nn.functional中的許多功能也是可用的。
  我們目前不提供任何內置的ScriptModule,例如Linear或Conv模塊。 此功能將在未來開發。 目前我們建議使用torch.jit.trace將標准的torch.nn模塊轉換為構造中的ScriptModules。

請注意,運算符也可以有關聯的塊,即prim :: Loop和prim :: If運算符。 在圖形打印輸出中,這些運算符被格式化以反映其等效的源代碼形式,以便於調試。
可以如圖所示檢查圖形以確認ScriptModule描述的計算以自動和手動方式是正確的,如下所述。

Tracing Edge Cases
存在一些邊緣情況,其中給定Python函數/模塊的跟蹤將不代表底層代碼。 這些案件可包括:

Tracing依賴於輸入的控制流(例如tensor的shapes)
Tracing Tensor視圖的就地操作(例如,在左側索引的賦值)
請注意,這些情況實際上可能在將來可被Trace。
Automatic Trace Checking
自動捕獲跟蹤中的許多錯誤的一種方法是使用torch.jit.trace()API上的check_inputs。 check_inputs獲取一系列輸入元組列表,這些元組將用於重新跟蹤計算並驗證結果。 例如:

def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
1
2
3
4
5
6
7
8
9
10
提供以下診斷信息:

ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%3 : Dynamic = aten::select(%0, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::select(%0, %4, %5)
%7 : Dynamic = aten::mul(%3, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Dynamic = aten::select(%0, %8, %9)
%11 : Dynamic = aten::mul(%7, %10)
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Dynamic = aten::select(%0, %12, %13)
%15 : Dynamic = aten::mul(%11, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Dynamic = aten::select(%0, %16, %17)
+ %19 : Dynamic = aten::mul(%15, %18)
- return (%15);
? ^
+ return (%19);
? ^
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
此消息向我們表明,在我們第一次跟蹤它和使用check_inputs跟蹤它時,計算之間存在差異。 實際上,loop_in_traced_fn體內的循環取決於輸入x的形狀,因此當我們嘗試另一個具有不同形狀的x時,跡線會有所不同。

在這種情況下,可以使用腳本來捕獲這樣的數據相關控制流:

def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)

for input_tuple in [inputs] + check_inputs:
torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
那就會產生:

graph(%x : Dynamic) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Dynamic = aten::select(%x, %2, %1)
%4 : int = aten::size(%x, %1)
%5 : int = prim::Constant[value=1]()
%result : Dynamic = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Dynamic) {
%9 : int = prim::Constant[value=0]()
%10 : Dynamic = aten::select(%x, %9, %i)
%result.2 : Dynamic = aten::mul(%7, %10)
%12 : int = prim::Constant[value=1]()
-> (%12, %result.2)
}
return (%result);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Tracer Warnings
跟蹤器在跟蹤計算中為幾個有問題的模式生成警告。 例如,在Tensor的切片(視圖)上跟蹤包含就地賦值的函數:

def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
1
2
3
4
5
6
生成幾個警告和一個只返回輸入的圖表:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
1
2
3
4
5
6
7
8
我們可以通過修改代碼以不使用就地更新來修復此問題,而是使用torch.cat構建結果張量:

def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
---------------------
作者:丶Shining
來源:CSDN
原文:https://blog.csdn.net/xxradon/article/details/86504906
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM