tf.GradientTape定義在tensorflow/python/eager/backprop.py文件中,從文件路徑也可以大概看出,GradientTape是eager模式下計算梯度用的,而eager模式(eager模式的具體介紹請參考文末鏈接)是TensorFlow 2.0的默認模式,因此tf.GradientTape是官方大力推薦的用法。下面就來具體介紹GradientTape的原理和使用。
Tape在英文中是膠帶,磁帶的含義,用在這里是由於eager模式帶來的影響。在TensorFlow 1.x靜態圖時代,我們知道每個靜態圖都有兩部分,一部分是前向圖,另一部分是反向圖。反向圖就是用來計算梯度的,用在整個訓練過程中。而TensorFlow 2.0默認是eager模式,每行代碼順序執行,沒有了構建圖的過程(也取消了control_dependency的用法)。但也不能每行都計算一下梯度吧?計算量太大,也沒必要。因此,需要一個上下文管理器(context manager)來連接需要計算梯度的函數和變量,方便求解同時也提升效率。
舉個例子:計算y=x^2在x = 3時的導數:
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6
例子中的watch函數把需要計算梯度的變量x加進來了。GradientTape默認只監控由tf.Variable創建的traiable=True屬性(默認)的變量。上面例子中的x是constant,因此計算梯度需要增加g.watch(x)函數。當然,也可以設置不自動監控可訓練變量,完全由自己指定,設置watch_accessed_variables=False就行了(一般用不到)。
GradientTape也可以嵌套多層用來計算高階導數,例如:
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
dy_dx = gg.gradient(y, x) # y’ = 2*x = 2*3 =6
d2y_dx2 = g.gradient(dy_dx, x) # y’’ = 2
另外,默認情況下GradientTape的資源在調用gradient函數后就被釋放,再次調用就無法計算了。所以如果需要多次計算梯度,需要開啟persistent=True屬性,例如:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, x) # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3
dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6
del g # 刪除這個上下文膠帶
最后,一般在網絡中使用時,不需要顯式調用watch函數,使用默認設置,GradientTape會監控可訓練變量,例如:
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
這樣即可計算出所有可訓練變量的梯度,然后進行下一步的更新。對於TensorFlow 2.0,推薦大家使用這種方式計算梯度,並且可以在eager模式下查看具體的梯度值。