訓練模型時,很多事情一開始都無法預測。比如之前我們為了找出迭代多少輪才能得到最佳驗證損失,可能會先迭代100次,迭代完成后畫出運行結果,發現在中間就開始過擬合了,於是又重新開始訓練。
類似的情況很多,於是我們想要實時監測訓練動態,並能根據訓練情況及時對模型采取一定的措施。Keras中的回調函數和tf的TensorBoard就是為此而生。
Keras回調函數
回調函數(callbacks)是在調用fit時傳入模型的一個對象,它在訓練過程中的不同時間點都會被模型調用。它可以訪問關於模型狀態和性能的所有可用數據,還可以采取行動:中斷訓練、保存模型、加載一組不同的權重或者改變模型的狀態。也就是說,之前在訓練模型的過程中,我們不知道模型的實時狀態,因此為了更好的監測和控制模型的訓練過程,我們派出了一個特派員——回調函數,它可以根據情況記錄、反饋或者采取措施。我們熟悉的訓練進度條和fit返回的history都是回調函數,只不過它倆因為太常用,所以被單獨拎出來。
fit和fit_generator函數都提供了callbacks接口。常用的回調函數有:
- ModelCheckpoint(在每輪過后保存當前模型);
- EarlyStopping(如果監控參數得不到改善就中斷訓練);
- LearningRateScheduler(在訓練過程中動態調整學習率);
- ReduceLROnPlateau(如果驗證表現得不到改善,可以用它降低學習率,跳出局部最小值);
- CSVLogger(將每個epoch的結果寫入CSV文件)。
- 其他回調函數,也可以根據需要自行編寫。
應用示例:
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
#fit提供callbacks接口,接收一個回調函數列表,可將任意個回調函數傳入模型中
callback_lists = []
callback_lists.append(EarlyStopping(monitor = 'acc', #監控模型的驗證精度
patience = 1)) #如果精度在多於一輪的時間(即兩輪)內不再改善,就中斷訓練
callback_lists.append(ModelCheckpoint(filepath = 'my_model.h5', #目標文件的保存路徑
monitor = 'val_loss', #監控驗證損失
save_best_only = True)) #只保存最佳模型
callback_lists.append(ReduceLROnPlateau(monitor = 'val_loss', #監控模型的驗證損失
factor = 0.1, #觸發時將學習率乘以系數0.1
patience = 10) #若驗證損失在10輪內都沒有改善,則觸發該回調函數
#由於回調函數要監控驗證損失和驗證精度,所以在調用fit時需要傳入validation_data
model.fit(x, y, epochs = 10, batch_size = 32,
callbacks = callbacks_list,
validation_data = (x_val, y_val))
TensorBoard:實時可視化工具
TensorBoard是內置於TensorFlow中基於瀏覽器的可視化工具,安裝TensorFlow時會自動安裝這個工具。簡單來說,它就是把訓練過程數據寫入文件,然后用瀏覽器查看的工具。在Keras中,它也被包裝成一個回調函數。
示例如下:
#引入Tensorboard
from keras.callbacks import TensorBoard
#定義回調函數列表,現在只放一個簡單的TensorBoard
log_path = './logs' #指定TensorBoard讀取的文件路徑,可以新建一個
callback_lists = [TensorBoard(log_dir=log_path, histogram_freq=1)]
#模型調用fit時,通過回調函數接口傳入
model.fit(...inputs and parameters..., callbacks=callback_lists)
為了在訓練的過程中可視化各項指標,需要自己在終端啟動TensorBoard。
打開終端的方式有兩種:一種是系統自帶的終端cmd;另一種是在Anaconda Prompt終端。選擇用哪種終端打開,根據當時安裝tensorflow時用的終端方式。我試了下cmd,總是出錯,但在Anaconda Prompt終端就能正常啟動。
啟動方式:在終端輸入 tensorboard --logdir=C:\Users...\logs (自己文件的路徑),就會返回一行信息,包含了一個http網址。這個地址一般是不會改變的,在瀏覽器中輸入提示的http地址,即可查看模型的訓練過程和相關狀態,如下圖所示。
Reference
書籍:Python深度學習