model.fit中的callbacks是做什么的


model.fit中的callbacks是做什么的

一、總結

一句話總結:

keras的callback參數可以幫助我們實現在訓練過程中的適當時機被調用。實現實時保存訓練模型以及訓練參數。

 

 

二、keras深度訓練1:fit和callback

轉自或參考:keras深度訓練1:fit和callback
http://blog.csdn.net/github_36326955/article/details/79794288

 

1. model.fit

model.fit(
    self, 
    x, 
    y, 
    batch_size=32, 
    nb_epoch=10, 
    verbose=1, 
    callbacks=[], 
    validation_split=0.0, 
    validation_data=None, 
    shuffle=True, 
    class_weight=None, 
    sample_weight=None
)

其中:

  1. x為輸入數據。如果模型只有一個輸入,那么x的類型是numpy array,如果模型有多個輸入,那么x的類型應當為list,list的元素是對應於各個輸入的numpy array。如果模型的每個輸入都有名字,則可以傳入一個字典,將輸入名與其輸入數據對應起來。
  2. y:標簽,numpy array。如果模型有多個輸出,可以傳入一個numpy array的list。如果模型的輸出擁有名字,則可以傳入一個字典,將輸出名與其標簽對應起來。
  3. batch_size:整數,指定進行梯度下降時每個batch包含的樣本數。訓練時一個batch的樣本會被計算一次梯度下降,使目標函數優化一步。
  4. nb_epoch:整數,訓練的輪數,訓練數據將會被遍歷nb_epoch次。Keras中nb開頭的變量均為”number of”的意思
  5. verbose:日志顯示,0為不在標准輸出流輸出日志信息,1為輸出進度條記錄,2為每個epoch輸出一行記錄
  6. callbacks:list,其中的元素是keras.callbacks.Callback的對象。這個list中的回調函數將會在訓練過程中的適當時機被調用,參考回調函數
  7. validation_split:0~1之間的浮點數,用來指定訓練集的一定比例數據作為驗證集。驗證集將不參與訓練,並在每個epoch結束后測試的模型的指標,如損失函數、精確度等。
  8. validation_data:形式為(X,y)或(X,y,sample_weights)的tuple,是指定的驗證集。此參數將覆蓋validation_spilt。
  9. shuffle:布爾值,表示是否在訓練過程中每個epoch前隨機打亂輸入樣本的順序。請注意:這個shuffle並不是對整個數據集打亂順序的,而是先split出訓練集和驗證集,然后對訓練集進行shuffle。
  10. class_weight:字典,將不同的類別映射為不同的權值,該參數用來在訓練過程中調整損失函數(只能用於訓練)。該參數在處理非平衡的訓練數據(某些類的訓練樣本數很少)時,可以使得損失函數對樣本數不足的數據更加關注。
  11. sample_weight:權值的numpy array,用於在訓練時調整損失函數(僅用於訓練)。可以傳遞一個1D的與樣本等長的向量用於對樣本進行1對1的加權,或者在面對時序數據時,傳遞一個的形式為(samples,sequence_length)的矩陣來為每個時間步上的樣本賦不同的權。這種情況下請確定在編譯模型時添加了sample_weight_mode=’temporal’。12345678910111234567891011

fit函數返回一個History的對象,其History.history屬性記錄了損失函數和其他指標的數值隨epoch變化的情況,如果有驗證集的話,也包含了驗證集的這些指標變化情況。

2. callback

keras的callback參數可以幫助我們實現在訓練過程中的適當時機被調用。實現實時保存訓練模型以及訓練參數。

2.1 ModelCheckpoint
keras.callbacks.ModelCheckpoint(
    filepath, 
    monitor='val_loss', 
    verbose=0, 
    save_best_only=False, 
    save_weights_only=False, 
    mode='auto', 
    period=1
)

其中:
1. filename:字符串,保存模型的路徑
2. monitor:需要監視的值
3. verbose:信息展示模式,0或1
4. save_best_only:當設置為True時,將只保存在驗證集上性能最好的模型,一般我們都會設置為True.
5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True時決定性能最佳模型的評判准則,例如,當監測值為val_acc時,模式應為max,當檢測值為val_loss時,模式應為min。在auto模式下,評價准則由被監測值的名字自動推斷。
6. save_weights_only:若設置為True,則只保存模型權重,否則將保存整個模型(包括模型結構,配置信息等)
7. period:CheckPoint之間的間隔的epoch數

2.2 EarlyStopping
from keras.callbacksimport EarlyStopping 

keras.callbacks.EarlyStopping(
    monitor='val_loss', 
    patience=0, 
    verbose=0, 
    mode='auto'
)

model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])

其中:
1. monitor:需要監視的量
2. patience:當early stop被激活(如發現loss相比上一個epoch訓練沒有下降),則經過patience個epoch后停止訓練。
3. verbose:信息展示模式
4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果檢測值停止下降則中止訓練。在max模式下,當檢測值不再上升則停止訓練。

2.3 LearningRateSchedule

學習率動態調整

keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', 
    factor=0.1, 
    patience=10, 
    verbose=0, 
    mode='auto', 
    epsilon=0.0001, 
    cooldown=0, 
    min_lr=0
)

其中:
1. monitor:被監測的量
2. factor:每次減少學習率的因子,學習率將以lr = lr*factor的形式被減少
3. patience:當patience個epoch過去而模型性能不提升時,學習率減少的動作會被觸發
4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果檢測值觸發學習率減少。在max模式下,當檢測值不再上升則觸發學習率減少。
5. epsilon:閾值,用來確定是否進入檢測值的“平原區”
6. cooldown:學習率減少后,會經過cooldown個epoch才重新進行正常操作
7. min_lr:學習率的下限

當學習停滯時,減少2倍或10倍的學習率常常能獲得較好的效果


自定義動態調整學習率:

def step_decay(epoch):
    initial_lrate = 0.01
    drop = 0.5
    epochs_drop = 10.0
    lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))
    return lrate
lrate = LearningRateScheduler(step_decay)
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])

具體可以參考這篇文章Using Learning Rate Schedules for Deep Learning Models in Python with Keras

2.4 記錄每一次epoch的訓練/驗證損失/准確度?

Model.fit函數會返回一個 History 回調,該回調有一個屬性history包含一個封裝有連續損失/准確的lists。代碼如下:

hist = model.fit(X, y,validation_split=0.2)  
print(hist.history)

Keras輸出的loss,val這些值如何保存到文本中去
Keras中的fit函數會返回一個History對象,它的History.history屬性會把之前的那些值全保存在里面,如果有驗證集的話,也包含了驗證集的這些指標變化情況,具體寫法

hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
with open('log_sgd_big_32.txt','w') as f:
    f.write(str(hist.history))
2.5 TensorBoard

from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
                          write_graph=True, write_images=False)
# define model
model.fit(X_train, Y_train,
          batch_size=batch_size,
          epochs=nb_epoch,
          validation_data=(X_test, Y_test),
          shuffle=True,
          callbacks=[tensorboard])

使用tensorboard時,在終端輸入

tensorboard --logdir path_to_current_dir
2.5 多個回調函數用逗號隔開
from keras.callbacks import TensorBoard
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.callbacks import ReduceLROnPlateau


# callbacks:
tb = TensorBoard(log_dir='./logs',  # log 目錄
                 histogram_freq=1,  # 按照何等頻率(epoch)來計算直方圖,0為不計算
                 batch_size=32,     # 用多大量的數據計算直方圖
                 write_graph=True,  # 是否存儲網絡結構圖
                 write_grads=False, # 是否可視化梯度直方圖
                 write_images=False,# 是否可視化參數
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None)

es=EarlyStopping(monitor='val_loss', patience=20, verbose=0)

mc=ModelCheckpoint(
    './logs/weight.hdf5',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    period=1
)

rp=ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=20,
    verbose=0,
    mode='auto',
    epsilon=0.0001,
    cooldown=0,
    min_lr=0
)

callbacks = [es,tb,mc,rp]

# start to train out model
bs = 100
ne = 1000
hist = model.fit(data, labels_cat,batch_size=bs,epochs=ne,
                      verbose=2,validation_split=0.25,callbacks=callbacks)

print("train process done!!")
 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM