tf.keras的回調函數實際上是一個類,一般是在model.fit時作為參數指定,用於控制在訓練過程開始或者在訓練過程結束,在每個epoch訓練開始或者訓練結束,在每個batch訓練開始或者訓練結束時執行一些操作,例如收集一些日志信息,改變學習率等超參數,提前終止訓練過程等等。
同樣地,針對model.evaluate或者model.predict也可以指定callbacks參數,用於控制在評估或預測開始或者結束時,在每個batch開始或者結束時執行一些操作,但這種用法相對少見。
大部分時候,keras.callbacks子模塊中定義的回調函數類已經足夠使用了,如果有特定的需要,我們也可以通過對keras.callbacks.Callbacks實施子類化構造自定義的回調函數。
所有回調函數都繼承至 keras.callbacks.Callbacks基類,擁有params和model這兩個屬性。
其中params 是一個dict,記錄了 training parameters (eg. verbosity, batch size, number of epochs...).
model即當前關聯的模型的引用。
此外,對於回調類中的一些方法如on_epoch_begin,on_batch_end,還會有一個輸入參數logs, 提供有關當前epoch或者batch的一些信息,並能夠記錄計算結果,如果model.fit指定了多個回調函數類,這些logs變量將在這些回調函數類的同名函數間依順序傳遞。
一,內置回調函數
-
BaseLogger: 收集每個epoch上metrics在各個batch上的平均值,對stateful_metrics參數中的帶中間狀態的指標直接拿最終值無需對各個batch平均,指標均值結果將添加到logs變量中。該回調函數被所有模型默認添加,且是第一個被添加的。
-
History: 將BaseLogger計算的各個epoch的metrics結果記錄到history這個dict變量中,並作為model.fit的返回值。該回調函數被所有模型默認添加,在BaseLogger之后被添加。
-
EarlyStopping: 當被監控指標在設定的若干個epoch后沒有提升,則提前終止訓練。
-
TensorBoard: 為Tensorboard可視化保存日志信息。支持評估指標,計算圖,模型參數等的可視化。
-
ModelCheckpoint: 在每個epoch后保存模型。
-
ReduceLROnPlateau:如果監控指標在設定的若干個epoch后沒有提升,則以一定的因子減少學習率。
-
TerminateOnNaN:如果遇到loss為NaN,提前終止訓練。
-
LearningRateScheduler:學習率控制器。給定學習率lr和epoch的函數關系,根據該函數關系在每個epoch前調整學習率。
-
CSVLogger:將每個epoch后的logs結果記錄到CSV文件中。
-
ProgbarLogger:將每個epoch后的logs結果打印到標准輸出流中。
二,自定義回調函數
可以使用callbacks.LambdaCallback編寫較為簡單的回調函數,也可以通過對callbacks.Callback子類化編寫更加復雜的回調函數邏輯。
如果需要深入學習tf.Keras中的回調函數,不要猶豫閱讀內置回調函數的源代碼。
import numpy as np import pandas as pd import tensorflow as tf from tensorflow.keras import layers,models,losses,metrics,callbacks import tensorflow.keras.backend as K # 示范使用LambdaCallback編寫較為簡單的回調函數 import json json_log = open('./data/keras_log.json', mode='wt', buffering=1) json_logging_callback = callbacks.LambdaCallback( on_epoch_end=lambda epoch, logs: json_log.write( json.dumps(dict(epoch = epoch,**logs)) + '\n'), on_train_end=lambda logs: json_log.close() ) # 示范通過Callback子類化編寫回調函數(LearningRateScheduler的源代碼) class LearningRateScheduler(callbacks.Callback): def __init__(self, schedule, verbose=0): super(LearningRateScheduler, self).__init__() self.schedule = schedule self.verbose = verbose def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, 'lr'): raise ValueError('Optimizer must have a "lr" attribute.') try: lr = float(K.get_value(self.model.optimizer.lr)) lr = self.schedule(epoch, lr) except TypeError: # Support for old API for backward compatibility lr = self.schedule(epoch) if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating: raise ValueError('The dtype of Tensor should be float') K.set_value(self.model.optimizer.lr, K.get_value(lr)) if self.verbose > 0: print('\nEpoch %05d: LearningRateScheduler reducing learning ' 'rate to %s.' % (epoch + 1, lr)) def on_epoch_end(self, epoch, logs=None): logs = logs or {} logs['lr'] = K.get_value(self.model.optimizer.lr)
參考:
開源電子書地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/
GitHub 項目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days