寫在前面
本文翻譯自Tensorflow團隊的文章Tensorflow Control Flow Implementation,部分內容加入了筆者自己的理解,如有不妥之處還望各位指教。
目錄
- 概覽
- 控制流核心概念
- 控制流結構的編譯
- 條件表達式
- while循環
- 實現
- 分布式條件表達式
- 分布式while循環
- 自動微分
概覽
本文將會介紹當前在Tensorflow中控制流操作的設計和實現。這是一篇基於原始設計的描述性文檔,設計的細節還請參考源代碼。
本文將要講述的內容是:
- 介紹Tensorflow為了處理控制流加入的5個核心的操作;
- 展示高層的控制流是如何通過5個基礎操作融入數據流圖的;
- 解釋加入了控制流的數據流圖是怎樣被Tensorflow運行時執行的,包括融合了多種設備(CPU,GPU,TPU)的分布式執行方式;
- 描述了對控制流結構如何自動求導;
控制流核心概念
Tensorflow中控制流的基礎設計理念是,通過引入少量的簡單基礎操作,為多樣的Tensorflow應用提供豐富的控制流表達。我們期望這些操作靈活且富有表現力,能夠作為高層的領域專用語言(DSL,Domain Specific Language)的編譯目標。它們需要很方便的嵌入Tensorflow目前的數據流模型中,並且可以方便的進行並行的、分布式的執行以及自動求導。本節將介紹這5種控制流相關的基本操作。它們與Dennis和Arvind在數據流機(dataflow machines)中引入的控制流操作很像。使用Switch和Merge可以使我們事先條件控制,將這5種基礎操作組合起來,可以使我們實現while循環。
在Tensorflow中,每一個op都會在一個執行幀(execution frame)中被執行,控制流操作負責創建和管理這些執行幀。比如,對於while循環,Tensorflow的運行時會創建一個執行幀,然后將所有屬於該while循環的操作放在這個執行幀中執行。不同執行幀中的操作可以並行執行,只要它們之間沒有數據依賴。
Switch:一個Switch操作根據控制輸入p的布爾值,將一個輸入張量d推進到某一個輸出(二選一)。只有到Switch操作的兩個輸入都准備好之后,它才會執行。
Merge:Merge操作將它的其中一個輸入推向輸出。當一個Merge操作的任意一個輸入准備好之后,Merge操作就會執行。在多個輸入都准備好的情況下,Merge操作的輸出不確定。
Enter(name):Enter操作將它的輸入推向名為name的執行幀。Enter操作實際上是把一個執行幀的張量推向它的子執行幀。同一個子執行幀上可能會有多個Enter操作,它們將不同的張量推向子執行幀。當輸入准備好之后,Enter操作就會執行。一個新的執行幀在它的第一個Enter操作執行之后開始執行。
Exit:Exit操作,將一個張量從一個子執行幀推向它的父執行幀。它的作用是將張量從子執行幀返回給父執行幀。一個子執行幀可能有多個Exit操作指向父執行幀,每個操作都會異步的將一個張量返回給父執行幀。當它的輸入准備好之后,Exit操作開始執行。
NextIteration:NextIteration操作將一個張量從當前執行幀的一輪迭代傳遞到下一輪迭代。Tensorflow的運行時在執行幀內部保存了一個迭代輪數。任何一個在執行幀中執行的操作都有唯一的一個迭代輪數的屬性,它可以幫助我們分辨一個迭代運算中不同的執行輪次。注意在一個執行幀中可能會有多個NextIteration操作。當執行幀的第N輪執行的第一個NextIteration操作開始執行時,Tensorflow的運行時開始執行第N+1輪的迭代。當更多的張量通過了NextIteration操作進入新的執行輪次時,新執行輪次中更多的操作就會開始運行。當輸入准備完成之后,NextIteration操作開始執行。
控制流結構的編譯
有了這5種基礎的操作,高級的程序部件,例如條件表達式和whiile循環就可以被編譯進入數據流圖,然后被Tensorflow的運行時執行。下面我們來看一下條件表達式和while循環是如何在Tensorflow內部實現的。
條件表達式
以下是構建條件表達式cond(pred, fn1, fn2)的數據流圖的高層偽代碼。為了簡化,我們忽略了實際使用中的細節,讀者可以在control_flow_ops.py中找到實現細節:
//構建true分支圖
context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
//構建false分支圖
context_t = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
//為輸出添加Merge節點
merges = [Merge([f,t]) for (f,t) in zip(res_f, res_t)]
return merges
對於條件表達式的每一個分支,我們創建了一個新的控制流上下文,並且在上下文中調用了圖構建的函數(fn1或者fn2)。條件上下文允許我們獲取任意的外部張量(不在上下文中創建的),並且插入一個合適的Switch操作來保證它會進入一個分支。這就保證了,只有當這個分支被選擇時,它對應的操作才會被執行。由於Tensorflow是異步執行的,外部的張量可能在不同的時間到達,因此我們為每一個外部張量准備了一個Switch操作來最大化並行度。
每個分支都返回了張量的列表(res_t或者res_f),因此我們又添加了一個Merge操作來對結果進行合並,這樣只要任何一個分支執行成功了,就能得到輸出(前面講到,對於Merge操作,只要其中一個輸入准備好了,就會產生輸出)。
讓我們來看一個簡單的例子:
tf.cond(x<y, lambda: tf.add(x,z), lambda: tf.square(y))
在生成的數據流圖中,Switch操作的插入是為了控制x,y,z張量的流動。在true/false分支,只有Switch操作的true/false的輸出才會被使用。由於Add操作的輸入來自Switch操作的true分支,因此只有x小於y時,Add操作才會被執行。同樣的,只有x大於等於y時,Square操作才會被執行。最終Merge操作發送Add或者Square的結果。如果條件表達式有多個結果,那么將會有多個Merge操作,每個結果對應一個Merge操作。
當然,利用Switch和Merge操作實現條件表達式還有很多方法,我們選擇當前的實現,主要是因為它更容易進行自動求導。
while循環
以下是構建數據流圖中while循環的高層偽代碼:
while_context = WhileContext()
while_context.Enter()
//為每一個循環變量添加Enter節點
enter_vars = [Enter(x, frame_name) for x in loop_vars]
//添加Merge節點,注意input[1]將會在后面被迭代
merge_vars = [Merge([x,x]) for x in enter_vars]
//構建循環條件子圖
pred_result = pred(*merge_vars)
//添加Switch節點
switch_vars = [Switch(x, pred_result) for x in merge_vars]
//構建循環體子圖
body_result = body(*[x[1] for x in switch_vars])
//添加NextIteration節點
next_vars = [NextIteration(x) for x in body_result]
//構建循環
for m,v in zip(merge_vars, next_vars):
m.op._update_input(1,v)
//添加Exit節點
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars
整個while循環圖創建在while循環的控制流上下文中。整個思路比較簡單。
從循環變量開始,我們為它們分別添加一個Enter操作和一個Merge操作。我們使用它們的結果(merge_vars)來構建判斷子圖,從而計算循環終止條件。
在添加了Switch操作之后,我們使用Switch操作的true分支來構建循環體子圖。循環體的結果需要進入下一輪迭代,因此我們添加了一個NextIteration操作,並且將其輸出指向Merge操作的第二個輸入,這樣就形成了循環,允許我們在執行圖是不斷的運行同樣的一組操作。
Switch操作的false輸出是整個while循環的輸出,因此我們在它后面加入了Exit操作,來返回運算結果。與條件表達式類似,while循環的上下文被用來追蹤在pred和lambda中使用的外部張量。這些外部張量被看做是循環常數,我們自動為每一個外部張量插入了一個Enter操作,使它在while循環的上下文內部能夠被訪問。嵌套的循環需要添加嵌套的Enter操作。
同樣的,讓我們看一個簡單的例子:
tf.while_loop(lambda i:i<10, lambda i: tf.add(i,1),[0])
如上圖所示,我們只有一個循環變量。如果有多個循環變量,我們需要添加多個Enter,Merge,Switch,NextIteration和Exit操作。這使得跨循環和跨迭代輪次的執行成為可能。你可能注意到我們省略了常量的表示方法,如果你想要理解更深層次的細節,請查看源代碼。
這種對於條件表達式和while循環的支持,使得我們可以表達任意嵌套的條件和循環。例如,一個循環體內可能嵌套着另外一個循環體。TF保證每個循環被賦予了一個唯一的幀名稱。
實現
Tensorflow的運行時負責對數據流圖進行執行。下面我們先來對此做一個快速的概覽。
為了在多台設備上運行,TF自動將計算操作分配到不同的設備上。基於設備分配,TF自動的將數據流圖划分成子圖,每台設備有一個子圖對應。當數據流圖的一條邊被圖分割切段時(邊兩側的節點分配在兩台不同的設備上),我們自動的插入一對send和recv節點,以便在設備間傳輸數據。一對send和recv節點通過一個唯一的鍵實現通信,recv節點主動的從send節點拉取數據。例如,以下就是將原圖分割到兩台設備后的結果。TF對於分割沒有添加任何限制,只要一個節點能夠在一台設備上進行運算,就可以被分配到這台設備。
如果一個子圖被分配到一個設備上運行,那么這個設備將會使用隸屬於它的執行器來執行這個子圖。執行器從source節點開始,依次執行已經准備好的節點。除了Merge節點之外,對於任何一個其他節點來說,只要它的輸入准備好了,這個節點就可以開始執行了。注意一張子圖中所有的recv節點都被認為是source節點。
如果沒有控制流,圖執行的過程會非常的直接:每個節點僅被執行一次,並且當所有節點都執行結束之后,整個圖的執行就完成了。控制流的引入帶來了一定的復雜性。有了控制流,一個節點可能被執行任意次(甚至包括0次)。執行器需要管理對於同一個節點的多個同時存在的執行實例,並且決定計算圖合適執行結束。
為了追蹤計算中產生的張量,執行器中的張量被使用一個形如(value, is_dead, tag)的元組來標識,value是張量值,is_dead是一個布爾值,用來標識這個張量是否在一個未執行的條件分支上,tag是這個張量的唯一標識(產生張量的節點的執行實例)。本質上,tag定義了執行的上下文,在同一個執行上下文下,一個操作最多被執行一次。tag是send/recv之間通信的鍵的一部分,用來辨識同一對send/recv節點的不同執行。
執行器遵循了如下的執行規則(注意,某個節點的所有輸入都必須包含同樣的tag)
Switch(p,d) = (r1,r2)
r1 = (value(d), p || is_dead(d),tag(d))
r2 = (value(d), !p || is_dead(d),tag(d))
Merge(d1,d2) = r
r = if is_dead(d1) then d2 else d1
Enter(d, frame_name) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag(d)/frame_name/0
Exit(d) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag1 where tag(d)=tag1/frame_name/n
NextIteration(d) = d1
value(d1) = value(d)
is_dead(d1) = is_dead(d)
tag(d1) = tag1/frame_name/(n+1) where tag(d) = tag1/frame_name/n
Op(d1,...,dm) = (r1,...,rn)
value(ri) = Op.Compute(value(d1),...,value(dm)) if !is_dead(ri)
is_dead(ri) = any(is_dead(d1),...,is_dead(dm)), for all i
tag(ri) = tag(d1), for all i
最后一個規則適用於所有的非控制流節點。注意只有當所有的輸入都有效時,計算才會執行。如果有一個dead輸入,我們將會跳過計算,而將dead信號傳遞下去。對於dead信號的傳遞將有助於支持控制流的分布式執行。
分布式條件表達式
對於分布式執行來說,一個條件表達式可能被分配到了不同的設備上,如下圖所示:
由於每一個recv節點都是source節點,並且隨時可能會開始執行,在設備B上的recv節點甚至在出於未選擇的條件分支上時也會執行。為了讓出於未選擇的分支上的recv節點的執行合理化,我們將is_dead標簽通過send節點跨設備傳輸到recv節點。這種信息會一直跨越設備傳輸下去。這種簡單的傳輸機制使得在分布式環境下的條件判斷更加自然,也有助於分布式環境下的while循環。
分布式的while循環
在分布式環境下,一個while循環(特別是循環體),可能被分割到不同的設備上。如果我們簡單的應用分割邏輯,然后在跨設備的節點之間插入send/recv,那么設備上的局部執行器將缺少准確執行while循環的信息。
讓我們通過一個例子來認識這個問題。在上述例子中,Op在循環體中,並且被分配給了設備B。一個簡單的分割可能會在Switch和Op之間插入一對send/recv節點來執行跨設備的數據傳輸。然而,這樣是無法工作的,因為設備B並不知道recv和Op操作是處在一個循環當中的,在執行完Op一次之后,設備B上的執行器就會認為,它的工作已經完成了(從設備B的角度看,它只需要從recv獲取數據,執行Op,然后將結果通過send發送出去,執行就結束了)。解決方案是,重寫數據流圖,在while循環體分配到的每個設備上,添加一個控制循環狀態機(如下圖中所示)。標量0被用來作為Enter節點的輸入。
這些控制循環為設備上的執行器提供了足夠的信息,使得它們可以像以前一樣獨立的執行,同時通過send/recv與其它設備通信。注意到圖中的虛線代表了控制輸入。
(具體執行過程分為0次執行,和大於等於1次執行兩種情況討論,這里就不寫了,大家可以自行分析)
注意到執行中有非常多的並行執行。例如,在接收到P之后,設備B可以開始下一輪迭代,或者停止執行。一個設備可能同時存在並行的多個執行輪次,並且兩個不同的設備還可以同時處在同一個循環的不同迭代輪次上。
這種while循環的分布式執行方式帶來的開銷是,任何一個參與的設備都必須在每一個迭代輪次里,接收來自產生P的設備傳遞過來的布爾張量。由於執行過程是高度並行的,這種開銷可以忽略不計了。
下圖展示了當一個while循環被分割到不同的設備上時是什么樣子。每個分割的部分都被添加了一個控制循環結構,用來控制while循環內部的recv操作。重寫之后的新圖與原圖是語義等價的。
對於嵌套的while循環,我們按照下圖所示的方式將控制循環堆疊起來。注意如果一台設備僅包含了外層循環的節點,我們不會在它上面添加與內層循環有關的控制循環結構。
自動微分
待補充。