TensorFlow 2 默認的即時執行模式(Eager Execution)為我們帶來了靈活及易調試的特性,但為了追求更快的速度與更高的性能,我們依然希望使用 TensorFlow 1.X 中默認的圖執行模式(Graph Execution)。此時,TensorFlow 2 為我們提供了 tf.function
模塊,結合 AutoGraph 機制,使得我們僅需加入一個簡單的 @tf.function
修飾符,就能輕松將模型以圖執行模式運行。
實現方式
只需要將我們希望以圖執行模式運行的代碼封裝在一個函數內,並在函數前加上 @tf.function
即可。
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time
np.random.seed(42) # 設置numpy隨機數種子
tf.random.set_seed(42) # 設置tensorflow隨機數種子
# 生成訓練數據
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1 # y=x^2+1 + 隨機噪聲
x_train = np.expand_dims(x, 1) # 將一維數據擴展為二維
y_train = np.expand_dims(y, 1) # 將一維數據擴展為二維
plt.plot(x, y, '.') # 畫出訓練數據
def create_model():
inputs = keras.Input((1,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = create_model() # 創建一個模型
loss_fn = keras.losses.MeanSquaredError() # 定義損失函數
optimizer = keras.optimizers.SGD() # 定義優化器
@tf.function # 將訓練過程轉化為圖執行模式
def train():
with tf.GradientTape() as tape:
y_pred = model(x_train, training=True) # 前向傳播,注意不要忘了training=True
loss = loss_fn(y_train, y_pred) # 計算損失
tf.summary.scalar("loss", loss, epoch+1) # 將損失寫入tensorboard
grads = tape.gradient(loss, model.trainable_variables) # 計算梯度
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 使用優化器進行反向傳播
return loss
epochs = 1000
begin_time = time.time() # 訓練開始時間
for epoch in range(epochs):
loss = train()
print('epoch:', epoch+1, '\t', 'loss:', loss.numpy()) # 打印訓練信息
end_time = time.time() # 訓練結束時間
print("訓練時長:", end_time-begin_time)
# 預測
y_pre = model.predict(x_train)
# 畫出預測值
plt.plot(x, y_pre.squeeze())
plt.show()
通過實驗得出結論:如果不使用@tf.function
,那么訓練時間大約為3秒。如果使用@tf.function
,訓練時間僅需要0.5秒。快了很多倍。
內在原理
使用@tf.function
的函數在執行時會生成一個計算圖,里面的操作就是計算圖的每個節點。下次調用相同的函數,且參數類型相同時,則會直接使用這個計算圖計算。若函數名不同或參數類型不同時,則會另外生成一個新的計算圖。
注意點
建議在函數內只使用 TensorFlow 的原生操作,不要使用過於復雜的 Python 語句,函數參數最好只包括 TensorFlow 張量或 NumPy 數組。
-
因為只有tf的原生操作才會在計算圖中生產節點。(如python的原生
print()
函數不會生成節點,而tensorflow的tf.print()
會) -
對於Tensorflow張量或Numpy數組作為參數的函數,只要類型相同便可重用之前的計算圖。而對於python原聲數據(如原生的整數、浮點數 1,1.5等)必須參數的值一模一樣才會重用之前的計算圖,否則的話會創建新的計算圖。
另外,一般而言,當模型由較多小的操作組成的時候, @tf.function
帶來的提升效果較大。而當模型的操作數量較少,但單一操作均很耗時的時候,則 @tf.function
帶來的性能提升不會太大。