一、計算圖簡介
在pytorch的官網上,可以看到一個簡單的計算圖示意圖, 如下。
import torch
from torch.autograd import Variable x = Variable(torch.randn(1, 10)) prev_h = Variable(torch.randn(1, 20)) W_h = Variable(torch.randn(20, 20)) W_x = Variable(torch.randn(20, 10)) i2h = torch.mm(W_x, x.t()) h2h = torch.mm(W_h, prev_h.t()) next_h = i2h + h2h next_h = next_h.tanh()
這個圖里有兩種節點:Variable節點和Function節點,Variable記錄運算數據,Function記錄運算操作。其中Variable節點又可以分為葉節點和非葉節點兩類。葉節點由用戶直接創建產生,而非葉節點則由Variable節點之間的運算操作產生,在圖的代碼中,x、prev_h、W_h、W_x屬於葉節點,i2h、h2h、next_h屬於非葉節點。
在這個圖上,節點之間的關系是很明確的:Variable非葉節點指向產生它的Function,因為產生某個Variable的Function只可能有一個,因此一個Variable只指向一個Function。Function的指向則是可以一對多的,因為一個運算函數往往可以接受大量的參數。Function指向兩種節點,當Function接受一個葉節點的Variable輸入時,Function需指向此Variable,當Function接受一個非葉節點Variable輸入時,Function需指向此Variable所指向的那個Function。
那這個計算圖是怎么建立的,具體實現又是怎么樣的呢?我們通過從頂向下、從底至上的兩個視角分別切入,來研究計算圖的形成過程。
二、框架性視角
2.1 Variable與Function
首先從頂至下,大略框架性地了解一下pytorch自動求導模塊幾個類,其中最重要的便是Variable類和Function類。此處注意的是,因為在C++代碼中與python代碼中均有名為Variable、Function的類,為示區別,如在不同語言中的類名有重復,在C++中的類稱為如Variable(C++),在python中的類稱為如Variable(py),當不刻意去分辨兩者的區別時則不特意加后綴括號,以此類推。
這里提到的類是Varaible(C++)和Function(C++),分別定義在torch/csrc/autograd/variable.h、torch/csrc/autograd/function.h,此處不復制粘貼代碼了。Variable(C++)類為自動求導過程中的核心數據類,與gif中的Variable節點可對應;Fucntion(C++)為自動求導過程中的函數類,與gif中的Fcuntion節點可對應。
首先解釋Variable(C++)類的幾個成員變量,
std::unique_ptr<thpp::Tensor> data:這是具體的底層數據,Variable(C++)為了完成自動求導的任務,在持有這個數據對象的條件下進行了一些包裝。
bool requires_grad,bool is_volatile:是兩個求導選項,在Variable(python)的構造函數中,可以傳中兩個選項參數。對於葉節點,如果有x.requires_grad==False或者x.volatile==True,則不需要對x進行求導。然而,這兩個選項的不同在於,對於非葉節點,當產生非葉節點的所有Variable節點的requires_grad都是False時,它的requires_grad才是False,而只要有一個產生它的Variable節點的volatile是True,那它的volatile就是True。前者通常用於固定某個不需或暫時不需迭代參數的模塊,如遷移學習、GAN訓練等場景中的常見情形;后者則通常用於明確地確認此部分任務不需要執行反向求導的情形,如一個深度學習模型的測試過程。這兩個選項可能會產生沖突,當沖突時,以volatile為准。其它閱讀資料可參考http://pytorch.org/docs/master/notes/autograd.html
std::shared_ptr<Function> grad_fn:這就是計算圖中,由Vairbale節點指向Function節點的連接。看命名的方式,是grad_fn,即gradient function。說明實際上,Variable節點與Function節點之間的連接,並不完全像gif中所示,連接前向計算時調用的函數對象,而是連接對應的梯度函數。這個連接是如何被建立的,將是文章后半段重點探索的內容。
然后考慮Function(C++)類的幾個成員變量,
bool is_executable:這是Funciton節點中,類似於Varibale節點里requires_grad、is_volatile標識的一個成員變量。如果給某個Function節點輸入的所有Vairbale節點都有requires_grad==False,或者少有一個的volatile==True,那這個Function的is_executable就會為False。結合Variable(C++)和Function(C++)求導標識的邏輯,可以在Function::flags方法里看見,邏輯正如此前描述一樣。
function_list next_functions:在Function(C++)定義的同文件內,也有function_list的定義,using function_list = std::vector<std::pair<std::shared_ptr<Function>, int>>。可見,通過next_functions可以訪問到一系列的Function(C++)對象,直觀地推斷,它就是gif圖中,Function節點與Function節點連接的關鍵。這個連接是如何被建立的,將是文章后半段重點探索的內容。
這里我們留了一個小疑問,在之前對gif的非嚴謹分析里,有葉節點輸入的Function節點,其會有一個指向葉節點對象的連接,但是在Function類里沒有發現有對應的成員變量,那這里是如何實現的呢,可留待具體實現分析的時候查看。
2.2 從C++到Python,以Function類為例
這一段希望解釋Function(C++)類與THPFunction等類的關系,不可避免地涉及部分python-C API的內容,但講解得較粗略,關於這部分詳細的內容可以去閱讀 專門講解python的C擴展的文章。
C++中,除了擁有底層邏輯的類以外,還有一層向python包裝的中間類,比如,Function(C++)類就是通過一個THPFunction類、一個PyTypeObject類實例THPFunctionType,包裝成一個python里可以訪問的torch._C._FunctionBase類的。這幾個類(或實例)之間的關系是什么呢?
首先看THPFunction類,它定義在torch/csrc/autograd/python_fucntion.h,THPFunction類持有一個PyFunction對象,而PyFunction類在同文件內定義並繼承Function類。THPFunction類的其它成員變量中,有部分是PyObject*類的,這部分通常被設計於暴露給python層,還有一部分不是PyObject*類的,它們在python層中不可見,僅在C++層的代碼邏輯中進行運作。
struct THPFunction { PyObject_HEAD PyObject *needs_input_grad; // Python tuple of tensors whose variables we should save. Set // by Python with 'save_for_backward'. If NULL, no tensors were // saved. PyObject *to_save; // Python pairs of distinct tensors which share storage. Set by // Python with 'mark_shared_storage'. If NULL, no tensors share // storage. PyObject *shared_pairs; // Python tuple of tensors which are not differentiable. Set by // Python with 'mark_non_differentiable'. If NULL, no tensors were // non-differentiable. PyObject *non_differentiable; // Python tuple of tensors which had inplace updates in the forward() // pass. Set by Python with 'mark_dirty'. If NULL, no tensors were // modified inplace. PyObject *dirty_tensors; std::vector<output_info_type> *output_info; std::vector<torch::autograd::SavedVariable> *saved_variables; // For each input, true if the input is a THPVariable std::vector<bool> *is_variable_input; char has_freed_buffers; // The C++ wrapper for this Python function. // See a comment in THPFunction_asFunction for details about this field. torch::autograd::PyFunction cdata; };
再看PyTypeObject類的實例THPFunctionType,它定義在torch/csrc/autograd/python_fucntion.cpp中,從注釋上可以看出來,它定義了一個python類的諸多基本操作。比如,如果python層創建一個對象的時候,要知道需要分配多大的空間,就到PyTypeObject負責tp_basicsize的那個slot里面去找,在這個個例里,它的值是sizeof(THPFunction);又如,這個類封裝到python層以后,有哪些方法呢,這個可以在tp_methods的這個域找到,在這個個例里,它的值是THPFunction_properties,THPFunction_properties這個變量也定義在同樣的文件夾下,它負責把C++的函數映射成python的類方法。
PyTypeObject THPFunctionType = { PyVarObject_HEAD_INIT(NULL, 0) "torch._C._FunctionBase", /* tp_name */ sizeof(THPFunction), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)THPFunction_dealloc, /* tp_dealloc */ 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_reserved */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ NULL, /* tp_doc */ (traverseproc)THPFunction_traverse, /* tp_traverse */ (inquiry)THPFunction_clear, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ THPFunction_methods, /* tp_methods */ 0, /* tp_members */ THPFunction_properties, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ 0, /* tp_init */ 0, /* tp_alloc */ THPFunction_new /* tp_new */ };
_FunctionBase這個類是在哪里創建的呢,實際就在上述同樣的文件里,THPFunctionType定義下方,可見函數THPFunction_initModule,其中有一句調用PyModule_AddObject,如此便注冊了一個新的python類_FunctionBase
bool THPFunction_initModule(PyObject *module) { if (PyType_Ready(&THPFunctionType) < 0) return false; Py_INCREF(&THPFunctionType); PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType); return true; }
我們看到PyModue_AddObject調用的參數里,只傳了THPFunctionType這個對象,卻沒見到THPFunction相關的信息,那_FunctionBase是怎么樣會與THPFunction扯上關系的呢?答案是通過THPFunctionType的各個slot下變量的具體定義。比如,tp_new這個slot下,值為THPFunction_new,THPFunction_new在同樣的文件下定義,它是一個函數
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { PyObject* obj = type->tp_alloc(type, 0); if (!obj) return NULL; // Python zero-initializes the object memory, so there's no need to initialize // most fields THPFunction* self = (THPFunction*)obj; new (&self->cdata) PyFunction(obj); self->cdata.num_inputs = -1; self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass); return obj; }
在這個函數體的第5行,它把新分配的內存強轉為了THPFunction *並賦值出去,這就實現了THPFunctionType對象和THPFunction類的聯系,在其它的屬性操作上也是如此。
總結來說,Function(C++)類定義了一個類的底層邏輯;THPFunction類會持有一個Function(C++)類對象,並暴露了一些可以在python層訪問的數據;PyTypeObject類的對象THPFunctionType里,各個slot定義了在python層里這個類的諸多基本操作(包括構造、析構、成員變量、方法等等等等);_FunctionBase是被包裝好后的python類,在python中可以通過import torch._C._FunctionBase訪問到它。以此類推地,Variable(C++)類、THPVariable類、PyTypeObject類的對象THPVariableType、_VariableBase也是類似的關系。
至於python層的Variable(py)類、Function(py)類,分別被定義在torch/autograd/variable.py、torch/autograd/function.py里,可以看到它們分別跟_VariableBase、_FunctionBase有一定的繼承關系
class Variable(_C._VariableBase): ............................ ............................ ............................
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
............................
............................
............................
三、實現細節
我們了解pytorch中幾個關鍵類的關系與結構之后,通過自底向下,從具體代碼追蹤調用的方法,探索整個計算圖是怎么樣形成的。
3.1 python層代碼
首先,假設有下面的python代碼,並得到結果
import torch from torch.autograd import Variable x = Variable(torch.Tensor([[1, 2, 3], [4, 5, 6]]), requires_grad=True) y = x.prod(dim=1, keepdim=True) print(y.grad_fn) print(y.grad_fn.next_functions)
<torch.autograd.function.ProdBackward object at 0x000001F18FB912F8>
((<AccumulateGrad object at 0x000001F18FBC2710>, 0),)
可以看到,通過簡單的創建數據、運算賦值這兩行代碼,計算圖已經建立起來了。這個操作的背后具體是怎么建立起來的呢,先從x.prod()這一方法的定義追蹤起。在torch/autograd/variable.py里可以看見
class Variable(_C._VariableBase): ……………………… def prod(self, dim=None, keepdim=None): return Prod.apply(self, dim, keepdim) ………………………
根據追蹤,可以發現這里提及的Prod類定義在torch/autograd/_functions/reduce.py里,繼承Function(py)類,有兩個類方法forward、backward,但是沒有顯式地定義apply方法。不管如何,此前的賦值語句可以視為
from torch.autograd._functions import Prod y = Prod.apply(x, 1, True) # same as y = x.prod(dim=1, keepdim=True)
因為Prod沒有顯示地定義apply方法,所以我們需要到它的父類里找apply方法,Prod繼承Function(py)類,Function(py)類的定義可以在torch/autograd/function.py里面找到
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # only for backward compatibility __call__ = _C._FunctionBase._do_forward @staticmethod def forward(*args, **kwargs): raise NotImplementedError @staticmethod def backward(*grad_outputs): raise NotImplementedError
Function(py)類有兩個待子類實現的方法,但也沒有定義apply方法,它的父類是以一種元類的形式動態產生的,這里不詳細解釋元類的具體產生知識,但我們至少看代碼上可以知道,可以到FunctionMeta、_C._FunctionBase、_ContextMethodMixin、_HookMixin幾個類里面去找apply方法。其中除了第二個類以外,另外三個都在python代碼中進行定義(和Function(py)類都在同一個文件內),很容易就發現都沒有定義apply方法。_FunctionBase類此前第一部分詳細提過,由C++源代碼包裝而得。
在這里,先暫緩一下apply方法的追蹤,分神看一看FunctionMeta這個元類的實現
class FunctionMeta(type): """Function metaclass. This metaclass sets up the following properties: _is_legacy: True if forward is not defined as a static method. _backward_cls: The Function class corresponding to the differentiated version of this function (which is generated on the fly by this metaclass). """ def __init__(cls, name, bases, attrs): for super_cls in cls.mro(): forward = super_cls.__dict__.get('forward') if forward is not None: has_static_forward = isinstance(forward, staticmethod) or isinstance(forward, classmethod) break setattr(cls, '_is_legacy', not has_static_forward) # old-style functions if not has_static_forward: return super(FunctionMeta, cls).__init__(name, bases, attrs) backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls}) setattr(cls, '_backward_cls', backward_fn) return super(FunctionMeta, cls).__init__(name, bases, attrs)
這個元類動態給生成的類設置了一個_backward_cls屬性,這個屬性的值是backward_fn,而backward_fn又是一個動態生成的類,這個類由type函數創建,第一個參數是類名,第二個參數是繼承的父類集合,第三個參數是類屬性名與具體對象的映射字典。我們在交互界面Prod來進行一些操作驗證一下。
from torch.autograd._functions import Prod Prod._backward_cls Out[2]: torch.autograd.function.ProdBackward Prod._backward_cls() Out[3]: <torch.autograd.function.ProdBackward at 0x1caee4ae450> Prod._backward_cls_.forward_cls Out[4]: torch.autograd._functions.reduce.Prod Prod._backward_cls._forward_cls.apply? Docstring: <no docstring> Type: builtin_function_or_method Prod._backward_cls.apply? Signature: Prod._backward_cls.apply(self, *args) Docstring: <no docstring> File: c:\anaconda3\lib\site-packages\torch\autograd\function.py Type: function
可以看到,FunctionMeta元類動態地給Prod類生成了一個_backward_cls屬性,這個屬性的值是一個類,類的名字叫ProdBackward, 符合源代碼中類名為name + 'Backward’的構造形式。將Prod._backward_cls實例化以后可以得到一個對應的對象。因為這個動態類建立的時候給它定義了一個_forward_cls的屬性,映射回類本身,所以Prod._backward_cls._forward_cls又能訪問回Prod類。
那像這類動態生成的Backward類的繼承關系又是怎么樣的呢?從type的第二個參數可以看到,它的父類是BackwardCFunction,定義在和Function類同樣的文件夾里。
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): _is_legacy = False def apply(self, *args): return self._forward_cls.backward(self, *args)
可以看到,BackwardCFunction類除了FunctionMeta類以外,也繼承了另外三個重要的類,然后覆寫了apply方法。所以ProdBackward和Prod的整體操作都是很像的,主要區別只有兩個。其一是Prod在創建的時候,會動態生成一個PordBackward類;其二是,Prod從_FunctionBase繼承apply方法,而ProdBackward繼承的apply則是其是父類覆寫后的apply。
3.2 C++層代碼 - THPFunction_apply
於是,現在回過頭來繼續追蹤Prod的apply方法。_FunctionBase在torch/csrc/autograd/python_function.cpp里被注冊,在同樣文件里,變量THPFunction_methods指明了C++函數與python對象方法的映射,可以看到_FunctionBase.apply相當於就是調用了THPFunction_apply函數。看一下這個函數的具體定義
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { HANDLE_TH_ERRORS THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); if (!backward_cls) return NULL; THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL)); if (!ctx_obj) return NULL; THPFunction* ctx = (THPFunction*)ctx_obj.get(); // Prepare inputs and allocate context (grad fn) auto info_pair = unpack_input<false>(_inputs); auto& unpacked_input = info_pair.first; auto& input_info = info_pair.second; bool is_volatile = input_info.flags.is_volatile; ctx->cdata.set_flags(std::move(input_info.flags)); ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = new std::vector<bool>(std::move(input_info.is_variable_input)); // Prepend ctx to tensor_input, in preparation for static method call auto num_args = PyTuple_GET_SIZE(_inputs); THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1)); PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release()); for (int i = 0; i < num_args; ++i) { PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i); Py_INCREF(arg); PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg); } // Call forward THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); if (!forward_fn) return NULL; THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input)); if (!tensor_outputs) return NULL; return process_outputs(ctx, unpacked_input, std::move(tensor_outputs), is_volatile); END_HANDLE_TH_ERRORS }
函數比較長,源代碼書寫的時候也明確地分成了幾個部分,每寫完一個部分就用一個空行來分隔,一共可以算作分了五段。每段代碼都簡單總結如下:
a.傳入參數PyObject* cls和PyObject* _inputs。cls代表調用這個函數的類本身,即python中的Prod類。_inputs則代表這個函數在python中調用時的所有參數,以tuple的形式打包成_inputs,對於最初的示例代碼而言,則相當於python中的(x, 1, True)。第一段寫成類python代碼時,形如
ctx = Prod._backward_cls()
b.參數通過unpack_input函數解析參數_inputs,得到unpacked_input和input_info兩個對象,該函數與返回值的定義將稍后細查,先略帶劇透地進行類比,unpacked_input.tensor_input的值類似於在python中的(x.data, 1, True),把原來的input_的值(x, 1, True)中的Variable換成了Tensor,其余參數保持一致。
c.重處理輸入參數,將這一段翻譯成類python代碼如下所示,基本目的就是生成一個比_inputs長1的tuple,把這個tuple的首位賦值為ctx,剩余的按unpacked_input.tensor_input依序填入。
num_args = len(_inputs) ctx_tensor_input = (None, ) * (num_args + 1) ctx_tensor_input[0] = ctx for i in range(num_args): arg = unpacked_input.tensor_input[i] ctx_tensor_input[i + 1] = arg
d.運行forward計算,得到一個Tensor對象,這一段寫成類python代碼如下所示。PyObject_CallObject調用了python代碼中的函數,也就是Prod類中的forward方法,稍后將回來追蹤此方法的實現。
forward_fn = Prod.forward tensor_outputs = forward_fn(*ctx_tensor_input) #ctx_tensor_input = (ctx, x.data, 1, True)
e.調用process_outputs,將返回的Tensor對象包裝成一個Variable對象,並返回
在五個部分中,a、c部分相對簡單,b、d、e部分均調用了其它效果不甚顯然的函數,計算圖是在哪一部分形成的連接呢?以下進行詳細的解析
3.3 C++層代碼 - unpack_input與參數解析
unpack_input是THPFunction_apply的b部分解析參數_inputs的重要函數,它返回由一個UnpackedInput實例,和一個InputFlags實例組成的std::pair,這兩個類的定義恰好在unpack_input的定義之前
struct UnpackedInput { PyObject *raw_input; THPObjectPtr tensor_input; variable_list input_vars; }; struct InputFlags { FunctionFlags flags; THPObjectPtr needs_input_grad; std::vector<bool> is_variable_input; };
從類及類成員變量的命名上可猜測,UnpackedInput主要保存_inputs解析后的數據,InputFlags類通過逐個解析_inputs的分量,來判斷每個變量的求導標識。在這種基本先驗思想指導之下,查看unpack_input的代碼
template<bool enforce_variables> std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { UnpackedInput unpacked; InputFlags flags; auto num_args = PyTuple_GET_SIZE(args); unpacked.tensor_input = PyTuple_New(num_args); flags.needs_input_grad = PyTuple_New(num_args); for (int i = 0; i < num_args; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); PyObject *new_arg; bool is_variable = THPVariable_Check(arg); flags.is_variable_input.push_back(is_variable); if (!is_variable) { if (enforce_variables) { THPUtils_setError("expected a Variable argument, but got %s", THPUtils_typename(arg)); throw python_error(); } Py_INCREF(arg); new_arg = arg; Py_INCREF(Py_False); PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False); } else { THPVariable* variable = (THPVariable*)arg; new_arg = THPVariable_get_data(variable); unpacked.input_vars.push_back(variable->cdata); PyObject* needs_grad = variable->cdata->requires_grad ? Py_True : Py_False; Py_INCREF(needs_grad); PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad); } PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg); } flags.flags = Function::flags(unpacked.input_vars); return std::make_pair(std::move(unpacked), std::move(flags)); }
由於代碼冗長,不如單獨追蹤兩個類的每個成員變量,如對於UnpackedInput類的tensor_input域而言,循環體實際如下。如果某變量不是一個THPVariable,則直接添加到unpacked.tensor_input中,如果某變量是一個THPVariable,則對unpacked.tensor_input添加這個變量的data域,相當於python中的Tensor。
template<bool enforce_variables> std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { .................................. for (int i = 0; i < num_args; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); PyObject *new_arg; bool is_variable = THPVariable_Check(arg) ..............................if (!is_variable) { .................................... new_arg = arg; .................................... } else { THPVariable* variable = (THPVariable*)arg; new_arg = THPVariable_get_data(variable); ................................................ } PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg); } ................................................... }
對兩個類各個成員變量進行分析,可以得知解析函數后具體產生的數值,簡易但不太嚴謹的簡述如下:
UnpackedInput.tensor_input:和_inputs一樣長的tuple,如果_inputs[i]不是THPVariable,那就維持tensor_input[i] = inputs[i],如果是THPVariable,則tensor_input[i] = input[i].data。值得注意的是,雖然這個成員變量的名字叫做tensor_input,但並不代表tensor_input的每一個分量都是tensor,實際上,它有可能包括大量的計算參數。
UnpackedInput.input_vars:如果_inputs[i]是一個THPVariable,添加_inputs[i].cdata(Variable(C++)類)到input_vars中,否則忽略
InputFlags.flags:UnpackedInput.input_vars把輸入的Variable(C++)收集完畢后,調用Function::flags來判斷求導標識。基本規則按照2.1中提過的進行。
InputFlags.needs_input_grad:和_inputs一樣長的tuple,如果_inputs[i]需要求導則為PyTrue,不需要求導則為PyFalse。不需要求導的可能性有兩種,一是_inputs[i]本身就不是THPVariable,二是雖然_inputs[i]是THPVariable但其求導標識不需要計算圖對它求導。
InputFlags.is_variable_input:和inputs一樣長的vector,如果_inputs[i]是THPVariable則為True,否則為False
對於本節開始的python代碼,幾個域的值以類python的風格寫出來大致如下:
_inputs:(x, 1, True)
UnpackedInput.tensor_input:(x.data, 1, True)
UnpackedInput.input_vars:(x,)
InputFlags.needs_input_grad:(True, False, False)
InputFlags.is_variable_input:(True, False, False)
非常值得注意的是InputsFlags.flags這個域,它調用了Function::flags,看一下Function::flags的源代碼,在torch/csrc/autograd/function.cpp
auto Function::flags(const variable_list& inputs) -> FunctionFlags { int num_inputs = inputs.size(); FunctionFlags f; f.is_executable = false; f.is_volatile = false; f.next_functions.resize(num_inputs); for (int i = 0; i != num_inputs; ++i) { auto& var = inputs[i]; if (var) { f.is_executable |= var->requires_grad; f.is_volatile |= var->is_volatile; if (var->grad_fn) { f.next_functions[i] = std::make_pair<>(var->grad_fn, var->output_nr); } else { f.next_functions[i] = std::make_pair<>(var->get_grad_accumulator(), 0); } } } f.is_executable &= !f.is_volatile; return f; }
可以看到,除了一直有提到的is_volatile、requires_grad、is_executable之間的關系,很令人感興趣的一點是FunctionFlags的next_functions域也有被操作,當某個input有grad_fn屬性的時候(換句話說,不是葉節點的時候),則FunctionFlag的next_functions域的某分量會指向這個grad_fn;當某個input沒有grad_fn的時候,則next_functions的這個分量會指向一個Function(C++)對象。這是一個從FunctionFlags類實例到Fucntion(C++)類實例的連接,實際上,它已經與Function(C++)類實例與Function(C++)類實例的連接非常接近了。是在哪里做出這個轉換的呢?可以從調用完iunpack_input后,THPFunction_apply的代碼往下查閱。
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { ................................................... // Prepare inputs and allocate context (grad fn) auto info_pair = unpack_input<false>(_inputs); auto& unpacked_input = info_pair.first; auto& input_info = info_pair.second; bool is_volatile = input_info.flags.is_volatile; ctx->cdata.set_flags(std::move(input_info.flags)); ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = new std::vector<bool>(std::move(input_info.is_variable_input)); .................................................... }
代碼摘抄區域的第5行,ctx->cdata調用下成員函數set_flags,參數是由_inputs經過unpack_input得到的input_info中的flags域。查看Function::set_flags這個方法,在torch/csrc/autograd/function.h中,顯而易見地,經過這個方法的調用,計算圖最終實現了Function(C++)與Function(C++)的連接。更詳細地,ctx是當前output即將會指向的梯度函數(在下面的process_outputs中實現),被連接的,即處於ctx的next_functions域中的,則是所有計算出output的inputs對象各種指向的梯度函數,這是一個由output梯度函數指向inputs梯度函數而形成相連的鏈條。
struct Function { ....................................... inline void set_flags(FunctionFlags&& flags) { is_executable = flags.is_executable; next_functions = std::move(flags.next_functions); } ....................................... }
3.4 python層代碼 - forward函數調用
在THPFunction_apply函數的d部分,C++代碼通過PyObject_CallObject調用python中的函數,進行前向運算。對於初始的python示例代碼,它相當於在python中調用了
Prod.forward(ctx, x.data, dim=1, keepdim=True)
觀察Prod的forward方法
class Prod(Function): @staticmethod def forward(ctx, input, dim=None, keepdim=None): ctx.dim = dim ctx.keepdim = False if keepdim is None else keepdim ctx.input_size = input.size() if dim is None: ctx.result = input.prod() ctx.save_for_backward(input) return input.new((ctx.result,)) else: if keepdim is not None: output = input.prod(dim, keepdim=keepdim) else: output = input.prod(dim) ctx.save_for_backward(input, output) return output
....................................
整個函數的定義相對明確且簡單,有兩個值得提的點。第一,輸入的input形參與返回值output,在python中均為Tensor類,而非Variable類;第二,將input和output控制為Tensor類的原因在於,底層的Tensor類已經設計好了一套數據運算方法,如果不調用基於Tensor的方法,而在Variable上建立新的運算規則,不利於分層維護的原則,也會造成較大的資源浪費。
3.4 C++層代碼 - process_outputs
THPFunction_apply的d部分,數據通過前向運算,得到了output,但是這個output只是一個Tensor,還未被包裝為Variable。從計算圖的角度看,現在雖然已經有了從Function節點到Function節點的連接,但是從Variable節點到Function節點的連接卻還未建立。e部分,process_outputs就是處理這種后續工作的。
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { HANDLE_TH_ERRORS ..................................
return process_outputs(ctx, unpacked_input, std::move(tensor_outputs), is_volatile); END_HANDLE_TH_ERRORS }
看process_outputs函數的定義
PyObject* process_outputs(THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr&& raw_output, bool is_volatile) { bool unpack_output = _ensure_tuple(raw_output); auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); THPObjectPtr outputs(PyTuple_New(num_outputs)); if (!outputs) throw python_error(); grad_fn->cdata.num_inputs = num_outputs; // Initialize t2var map t2var_type t2var; for (auto& c_var : unpacked.input_vars) { THPVariable* py_var = (THPVariable*)c_var->pyobj; t2var.emplace(py_var->data, py_var); } std::unordered_set<PyObject *> dirty_inputs; _mark_dirty(grad_fn, t2var, dirty_inputs); _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile); _join_version_counters(grad_fn, t2var); if (grad_fn->cdata.is_executable) { _mark_non_differentiable(grad_fn, t2var); _save_variables(grad_fn, t2var); } else { // Remove unnecessary attributes Py_XDECREF(grad_fn->to_save); grad_fn->to_save = NULL; Py_XDECREF(grad_fn->non_differentiable); grad_fn->non_differentiable = NULL; } // Unpack the output, unless .forward() returned a tuple if (unpack_output) { PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0); Py_INCREF(output);_sa return output; } return outputs.release(); }
看到這個函數一樣繼續往下調用了很多其它的函數,按順序包括_mark_dirty、_wrap_outputs、_join_version_counters、_mark_non_differentiable、_save_variables。這幾個函數可以分成兩類,_wrap_outputs作為核心的包裝函數屬於一類,其它的四個函數屬於另一類。轉到_mark_diry等四個函數的定義一看,函數內部開始時,都會先檢查某個grad_fn的成員變量。
static void _mark_dirty(THPFunction *self, t2var_type &t2var, std::unordered_set<PyObject *> &dirty_inputs) { // Increase versions of modified tensors if (!self->dirty_tensors) return; ................................ } static void _save_variables(THPFunction* self, t2var_type &t2var) { if (!self->to_save) return; ................................. } static void _join_version_counters(THPFunction *self, t2var_type &t2var) { if (!self->shared_pairs) return; ................................ } static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var) { if (!self->non_differentiable) return; ................................ }
這些成員變量都是從哪里來的呢?檢查THPFunction的類定義dirty_tensors、to_save、shared_pairs、non_fidderentiable均有PyObject*,可以在python中訪問到。在torch/autograd/function.py的python源代碼中,可以看見一個 _ContextMethodMixin類,它是Function類的元類組成部分之一,它的方法操作這這些變量的賦值。另外,從這些方法的注釋里可以發現,這些方法只允許在Funciton(py)類的forward方法中被調用,各自解決一些特定的問題,如mark_dirty處理數據原地操作后如何正確建立計算圖的問題。
class _ContextMethodMixin(object): def save_for_backward(self, *tensors): """Saves given tensors for a future call to :func:`~Function.backward`. **This should be called at most once, and only from inside the** :func:`forward` **method.** Later, saved tensors can be accessed through the :attr:`saved_tensors` attribute; or, if the corresponding Variable is needed (e.g. for double backwards), those can be accessed through the :attr:`saved_variables` attribute. Before returning them to the user, a check is made, to ensure they weren't used in any in-place operation that modified their content. Arguments can also be ``None``. """ self.to_save = tensors def mark_dirty(self, *args): """Marks given tensors as modified in an in-place operation. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be inputs.** Every tensor that's been modified in-place in a call to :func:`forward` should be given to this function, to ensure correctness of our checks. It doesn't matter whether the function is called before or after modification. """ self.dirty_tensors = args def mark_shared_storage(self, *pairs): """Marks that given pairs of distinct tensors are sharing storage. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be pairs of (input, output).** If some of the outputs are going to be tensors sharing storage with some of the inputs, all pairs of (input_arg, output_arg) should be given to this function, to ensure correctness checking of in-place modification. The only exception is when an output is exactly the same tensor as input (e.g. in-place ops). In such case it's easy to conclude that they're sharing data, so we don't require specifying such dependencies. This function is not needed in most functions. It's primarily used in indexing and transpose ops. """ self.shared_pairs = pairs def mark_non_differentiable(self, *args): """Marks outputs as non-differentiable. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be outputs.** This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in :meth:`~Function.backward`, but it's always going to be ``None``. This is used e.g. for indices returned from a max :class:`Function`. """ self.non_differentiable = args
因為C++里對應的處理函數相對比較繁瑣復雜,這里就不一一詳解了。僅以當前舉例的Prod類為主,它的forward函數只調用了save_for_backward函數,則在幾個功能函數中,只看對應的其中一個函數。
了解這些功能函數,可以注意看_wrap_outputs這個從Tensor向Variable包裝的核心函數。
static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
PyObject *outputs, bool is_volatile){
// Wrap outputs in Variables auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self); Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); if (self->cdata.is_executable) { self->output_info = new std::vector<output_info_type>(); self->output_info->reserve(num_outputs); } for (int i = 0; i < num_outputs; i++) { PyObject *output = PyTuple_GET_ITEM(raw_output, i); THPVariable *output_var; auto it = t2var.find(output); if (it == t2var.end()) { // A completely new tensor - just wrap it and continue if (is_volatile) { output_var = (THPVariable*)THPVariable_NewVolatile(output); } else { output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); } } else {
.....................................
.....................................
.....................................
//long long long code
} if (!output_var) throw python_error(); if (self->output_info) { auto& output_tensor = *output_var->cdata->data; self->output_info->emplace_back( (PyObject *)getPyTypeObject(output_tensor), output_tensor.getDevice(), output_tensor.sizes() ); } t2var[output] = output_var; output_var->cdata->output_nr = i; PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var); } }
第一個參數self,在當前情況下,實參為一路傳進來的THPFunction* ctx;第二個參數t2var,是一個由輸入Tensor到對應輸入Variable的無序映射,在process_outputs作用域中生成,在當前情況下,t2var的值類似於python中的字典{x.data: x};第三個參數dirty_inputs,因為在_mark_dirty中直接return了,所以是一個空的集合,在本次調用中也不起作用;第四個參數raw_output,就是由python層forward方法計算得到的輸出Tensor,需要進行包裝的數據;第五個參數outputs用於返回包裝好后的值;第六個參數is_volatile作為求導選項標識傳入。
函數開始先做一些is_volatile、is_executable的檢查,如果確實需要求導,則開一個輸出變量大小的空間,然后進入循環。