使用@tf.function加快訓練速度


TensorFlow 2 默認的即時執行模式(Eager Execution)為我們帶來了靈活及易調試的特性,但為了追求更快的速度與更高的性能,我們依然希望使用 TensorFlow 1.X 中默認的圖執行模式(Graph Execution)。此時,TensorFlow 2 為我們提供了 tf.function 模塊,結合 AutoGraph 機制,使得我們僅需加入一個簡單的 @tf.function 修飾符,就能輕松將模型以圖執行模式運行。

實現方式

只需要將我們希望以圖執行模式運行的代碼封裝在一個函數內,並在函數前加上 @tf.function 即可。

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time

np.random.seed(42)  # 設置numpy隨機數種子
tf.random.set_seed(42)  # 設置tensorflow隨機數種子

# 生成訓練數據
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1  # y=x^2+1 + 隨機噪聲
x_train = np.expand_dims(x, 1)  # 將一維數據擴展為二維
y_train = np.expand_dims(y, 1)  # 將一維數據擴展為二維
plt.plot(x, y, '.')  # 畫出訓練數據


def create_model():
    inputs = keras.Input((1,))
    x = keras.layers.Dense(10, activation='relu')(inputs)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()  # 創建一個模型
loss_fn = keras.losses.MeanSquaredError()  # 定義損失函數
optimizer = keras.optimizers.SGD()  # 定義優化器


@tf.function  # 將訓練過程轉化為圖執行模式
def train():
    with tf.GradientTape() as tape:
        y_pred = model(x_train, training=True)  # 前向傳播,注意不要忘了training=True
        loss = loss_fn(y_train, y_pred)  # 計算損失
        tf.summary.scalar("loss", loss, epoch+1)  # 將損失寫入tensorboard
    grads = tape.gradient(loss, model.trainable_variables)  # 計算梯度
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 使用優化器進行反向傳播
    return loss


epochs = 1000
begin_time = time.time()  # 訓練開始時間
for epoch in range(epochs):
    loss = train()

    print('epoch:', epoch+1, '\t', 'loss:', loss.numpy())  # 打印訓練信息
end_time = time.time()  # 訓練結束時間

print("訓練時長:", end_time-begin_time)

# 預測
y_pre = model.predict(x_train)

# 畫出預測值
plt.plot(x, y_pre.squeeze())
plt.show()

通過實驗得出結論:如果不使用@tf.function,那么訓練時間大約為3秒。如果使用@tf.function,訓練時間僅需要0.5秒。快了很多倍。

內在原理

使用@tf.function的函數在執行時會生成一個計算圖,里面的操作就是計算圖的每個節點。下次調用相同的函數,且參數類型相同時,則會直接使用這個計算圖計算。若函數名不同或參數類型不同時,則會另外生成一個新的計算圖。

注意點

建議在函數內只使用 TensorFlow 的原生操作,不要使用過於復雜的 Python 語句,函數參數最好只包括 TensorFlow 張量或 NumPy 數組。

  • 因為只有tf的原生操作才會在計算圖中生產節點。(如python的原生print()函數不會生成節點,而tensorflow的tf.print()會)

  • 對於Tensorflow張量或Numpy數組作為參數的函數,只要類型相同便可重用之前的計算圖。而對於python原聲數據(如原生的整數、浮點數 1,1.5等)必須參數的值一模一樣才會重用之前的計算圖,否則的話會創建新的計算圖。

另外,一般而言,當模型由較多小的操作組成的時候, @tf.function 帶來的提升效果較大。而當模型的操作數量較少,但單一操作均很耗時的時候,則 @tf.function 帶來的性能提升不會太大。

參考

https://tf.wiki/zh_hans/basic/tools.html#tf-function


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM