【tf.keras】TensorFlow 1.x 到 2.0 的 API 變化


TensorFlow 2.0 版本將 keras 作為高級 API,對於 keras boy/girl 來說,這就很友好了。tf.keras 從 1.x 版本遷移到 2.0 版本,需要注意幾個地方。

1. 設置隨機種子

import tensorflow as tf

# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)

2. 設置並行線程數和動態分配顯存

import tensorflow as tf
from tensorflow.python.keras import backend as K

import os
# 將程序限定在一塊GPU上
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
                         inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True  # 不全部占滿顯存, 按需分配
K.set_session(tf.Session(config=config))

# TF 2.0,由於之前限定了GPU可見范圍,這里只能看到0號GPU
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
print(gpus)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

3. model.compile() 中設置 metrics=['acc'] 或者 ['accuracy'],會影響 model.fit() 生成的 log,callbacks.ModelCheckpoint 需要對應填 val_acc 或者 val_accuracy:

from tensorflow.python.keras import callbacks

# TF 2.0, acc and val_acc
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['acc'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)

# TF 2.0, accuracy and val_accuracy
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)

4. 舍棄 model.fit_generator() 函數

model.fit_generator() 函數在 TF 2.x 中合並到 model.fit() 函數中,並且在 TF 2.0 版本,該函數有問題,不能很好利用 GPU,訓練速度很慢:
Performance: Training is much slower in TF v2.0.0 VS v1.14.0 when using Tf.Keras and model.fit_generator #33024

TF 2.0 版本的 model.fit() 在傳入 generator 時需要手動設置 model.fit(shuffle=False)。

解決辦法:直接使用 model.fit() 函數,並且升級到 TF 2.1。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM