參考文獻:https://blog.csdn.net/guanxs/article/details/102471843
在TensorFlow 1.x靜態圖時代,我們知道每個靜態圖都有兩部分,一部分是前向圖,另一部分是反向圖。反向圖就是用來計算梯度的,用在整個訓練過程中。而TensorFlow 2.0默認是eager模式,每行代碼順序執行,沒有了構建圖的過程(也取消了control_dependency的用法)。但也不能每行都計算一下梯度吧?計算量太大,也沒必要。因此,需要一個上下文管理器(context manager)來連接需要計算梯度的函數和變量,方便求解同時也提升效率。
舉個例子:計算y=x^2在x = 3時的導數:
x = tf.constant(3.0) with tf.GradientTape() as g: g.watch(x) y = x * x dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6
例子中的watch函數把需要計算梯度的變量x加進來了。GradientTape默認只監控由tf.Variable創建的traiable=True屬性(默認)的變量。上面例子中的x是constant,因此計算梯度需要增加g.watch(x)函數。當然,也可以設置不自動監控可訓練變量,完全由自己指定,設置watch_accessed_variables=False就行了(一般用不到)。
GradientTape也可以嵌套多層用來計算高階導數,例如:
x = tf.constant(3.0) with tf.GradientTape() as g: g.watch(x) with tf.GradientTape() as gg: gg.watch(x) y = x * x dy_dx = gg.gradient(y, x) # y’ = 2*x = 2*3 =6 d2y_dx2 = g.gradient(dy_dx, x) # y’’ = 2
另外,默認情況下GradientTape的資源在調用gradient函數后就被釋放,再次調用就無法計算了。所以如果需要多次計算梯度,需要開啟persistent=True屬性,例如:
x = tf.constant(3.0) with tf.GradientTape(persistent=True) as g: g.watch(x) y = x * x z = y * y dz_dx = g.gradient(z, x) # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3 dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6 del g # 刪除這個上下文tape
最后,一般在網絡中使用時,不需要顯式調用watch函數,使用默認設置,GradientTape會監控可訓練變量,例如:
with tf.GradientTape() as tape: predictions = model(images) loss = loss_object(labels, predictions) gradients = tape.gradient(loss, model.trainable_variables)
這樣即可計算出所有可訓練變量的梯度,然后進行下一步的更新。對於TensorFlow 2.0,推薦大家使用這種方式計算梯度,並且可以在eager模式下查看具體的梯度值。
根據上面的例子說一下tf.GradientTape這個類的常見的屬性和函數,更多的可以去官方文檔來看。
__init__(persistent=False,watch_accessed_variables=True)
作用:創建一個新的GradientTape
參數:
persistent: 布爾值,用來指定新創建的gradient tape是否是可持續性的。默認是False,意味着只能夠調用一次gradient()函數。
watch_accessed_variables: 布爾值,表明這個gradien tap是不是會自動追蹤任何能被訓練(trainable)的變量。默認是True。要是為False的話,意味着你需要手動去指定你想追蹤的那些變量。
比如在上面的例子里面,新創建的gradient tape設定persistent為True,便可以在這個上面反復調用gradient()函數。
watch(tensor)
作用:確保某個tensor被tape追蹤
參數:
tensor: 一個Tensor或者一個Tensor列表
gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
作用:根據tape上面的上下文來計算某個或者某些tensor的梯度
參數:
target: 被微分的Tensor或者Tensor列表,你可以理解為經過某個函數之后的值
sources: Tensors 或者Variables列表(當然可以只有一個值). 你可以理解為函數的某個變量
output_gradients: a list of gradients, one for each element of target. Defaults to None.
unconnected_gradients: a value which can either hold ‘none’ or ‘zero’ and alters the value which will be returned if the target and sources are unconnected. The possible values and effects are detailed in ‘UnconnectedGradients’ and it defaults to ‘none’.
返回:
一個列表表示各個變量的梯度值,和source中的變量列表一一對應,表明這個變量的梯度。