【tensorflow2.0】回調函數callbacks


tf.keras的回調函數實際上是一個類,一般是在model.fit時作為參數指定,用於控制在訓練過程開始或者在訓練過程結束,在每個epoch訓練開始或者訓練結束,在每個batch訓練開始或者訓練結束時執行一些操作,例如收集一些日志信息,改變學習率等超參數,提前終止訓練過程等等。

同樣地,針對model.evaluate或者model.predict也可以指定callbacks參數,用於控制在評估或預測開始或者結束時,在每個batch開始或者結束時執行一些操作,但這種用法相對少見。

大部分時候,keras.callbacks子模塊中定義的回調函數類已經足夠使用了,如果有特定的需要,我們也可以通過對keras.callbacks.Callbacks實施子類化構造自定義的回調函數。

所有回調函數都繼承至 keras.callbacks.Callbacks基類,擁有params和model這兩個屬性。

其中params 是一個dict,記錄了 training parameters (eg. verbosity, batch size, number of epochs...).

model即當前關聯的模型的引用。

此外,對於回調類中的一些方法如on_epoch_begin,on_batch_end,還會有一個輸入參數logs, 提供有關當前epoch或者batch的一些信息,並能夠記錄計算結果,如果model.fit指定了多個回調函數類,這些logs變量將在這些回調函數類的同名函數間依順序傳遞。

一,內置回調函數

  • BaseLogger: 收集每個epoch上metrics在各個batch上的平均值,對stateful_metrics參數中的帶中間狀態的指標直接拿最終值無需對各個batch平均,指標均值結果將添加到logs變量中。該回調函數被所有模型默認添加,且是第一個被添加的。

  • History: 將BaseLogger計算的各個epoch的metrics結果記錄到history這個dict變量中,並作為model.fit的返回值。該回調函數被所有模型默認添加,在BaseLogger之后被添加。

  • EarlyStopping: 當被監控指標在設定的若干個epoch后沒有提升,則提前終止訓練。

  • TensorBoard: 為Tensorboard可視化保存日志信息。支持評估指標,計算圖,模型參數等的可視化。

  • ModelCheckpoint: 在每個epoch后保存模型。

  • ReduceLROnPlateau:如果監控指標在設定的若干個epoch后沒有提升,則以一定的因子減少學習率。

  • TerminateOnNaN:如果遇到loss為NaN,提前終止訓練。

  • LearningRateScheduler:學習率控制器。給定學習率lr和epoch的函數關系,根據該函數關系在每個epoch前調整學習率。

  • CSVLogger:將每個epoch后的logs結果記錄到CSV文件中。

  • ProgbarLogger:將每個epoch后的logs結果打印到標准輸出流中。

二,自定義回調函數

可以使用callbacks.LambdaCallback編寫較為簡單的回調函數,也可以通過對callbacks.Callback子類化編寫更加復雜的回調函數邏輯。

如果需要深入學習tf.Keras中的回調函數,不要猶豫閱讀內置回調函數的源代碼。

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics,callbacks
import tensorflow.keras.backend as K 
 
# 示范使用LambdaCallback編寫較為簡單的回調函數
 
import json
json_log = open('./data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch = epoch,**logs)) + '\n'),
    on_train_end=lambda logs: json_log.close()
)
 
# 示范通過Callback子類化編寫回調函數(LearningRateScheduler的源代碼)
 
class LearningRateScheduler(callbacks.Callback):
 
    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose
 
    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:  
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
            raise ValueError('The dtype of Tensor should be float')
        K.set_value(self.model.optimizer.lr, K.get_value(lr))
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler reducing learning '
                 'rate to %s.' % (epoch + 1, lr))
 
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)

 

參考:

開源電子書地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub 項目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM