0. 簡介
記錄使用 tf.keras 時遇到的各種問題。
tf.keras 是 keras 的未來,keras 作為 TensorFlow 的高級 API,大大簡化 TensorFlow 代碼的編寫過程。
Keras(單獨的)、TensorFlow 1.x 和 TensorFlow 2.0 的 keras API 變化不大,手冊可以通用。
tf.keras 快速入門:初學者的 TensorFlow 2.0 教程 | TensorFlow Core
1. 安裝
1.1 安裝 CUDA 和 cuDNN
【tf.keras】Linux 非 root 用戶安裝 CUDA 和 cuDNN
2. 數據集
2.1 使用 tensorflow_datasets 導入公共數據集
【tf.keras】tensorflow datasets,tfds
2.2 數據集過大導致內存溢出
【tf.keras】在 cifar 上訓練 AlexNet,數據集過大導致 OOM
2.3 加載 cifar10 數據時報錯
3. 評價指標
3.1 實現 F1 socre、precsion、recall
在整個數據集而不是單個 batch 上實現 F1 socre、precsion、recall 等評價指標:
【tf.keras】實現 F1 score、precision、recall 等 metric
4. 優化器
4.1 AdamW 優化器示例程序
【tf.keras】AdamW: Adam with Weight decay
4.2 tf.keras 1.x 在使用 learning rate decay 時不要使用 tf.train 內的優化器
【tf.keras】tf.keras使用tensorflow中定義的optimizer
5. 模型
5.1 模型復現
【tf.keras】tf.keras模型復現
(注意:在CPU上訓練才能完全復現模型)
5.2 加載 AlexNet 預訓練模型
【tf.keras】tf.keras加載AlexNet預訓練模型
5.3 循環訓練模型導致 OOM
6. TensorFlow API 變化
5.1 TF 1.x 到 TF 2.0 API 變化,隨機種子、動態分配顯存
【tf.keras】TensorFlow 1.x 到 2.0 的 API 變化
5.2 TF 2.1 API 變化
TensorFlow 2.1 將 fit_generator(), evaluate_generator(), predict_generator() 等函數分別合並到 fit(),evaluate(),predict() 里。