前言
是的,除了水報錯文,我也來寫點其他的。本文主要介紹Keras中以下三個函數的用法:
fit()
fit_generator()
train_on_batch()
當然,與上述三個函數相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不詳細說了,舉一反三嘛。
環境
本文的代碼是在以下環境下進行測試的:
Windows 10
Python 3.6
TensorFlow 2.0 Alpha
異同
大家用Keras也就圖個簡單快捷,但是在享受簡單快捷的時候,也常常需要些定制化需求,除了model.fit(),有時候model.fit_generator()和model.train_on_batch()也很重要。那么,這三個函數有什么異同呢?Adrian Rosebrock [1] 有如下總結:
當你使用.fit()函數時,意味着如下兩個假設:
訓練數據可以 完整地 放入到內存(RAM)里
數據已經不需要再進行任何處理了
這兩個原因解釋的非常好,之前我運行程序的時候,由於數據集太大(實際中的數據集顯然不會都像 TensorFlow 官方教程里經常使用的 MNIST 數據集那樣小),一次性加載訓練數據到fit()函數里根本行不通:
history = model.fit(train_data, train_label) // Bomb!!!
1
於是我想,能不能先加載一個batch訓練,然后再加載一個batch,如此往復。於是我就注意到了fit_generator()函數。什么時候該使用fit_generator函數呢?Adrian Rosebrock 的總結道:
內存不足以一次性加載整個訓練數據的時候
需要一些數據預處理(例如旋轉和平移圖片、增加噪音、擴大數據集等操作)
在生成batch的時候需要更多的處理
對於我自己來說,除了數據集太大的緣故之外,我需要在生成batch的時候,對輸入數據進行padding,所以fit_generator()就派上了用場。下面介紹如何使用這三種函數。
fit()函數
fit()函數其實沒什么好說的,大家在看TensorFlow教程的時候已經見識過了。此外插一句話,tf.data.Dataset對不規則的序列數據真是不友好。
import tensorflow as tf
model = tf.keras.models.Sequential([
... // 你的模型
])
model.fit(train_x, // 訓練輸入
train_y, // 訓練標簽
epochs=5 // 訓練5輪
)
1
2
3
4
5
6
7
8
9
10
fit_generator()函數
fit_generator()函數就比較重要了,也是本文討論的重點。fit_generator()與fit()的主要區別就在一個generator上。之前,我們把整個訓練數據都輸入到fit()里,我們也不需要考慮batch的細節;現在,我們使用一個generator,每次生成一個batch送給fit_generator()訓練。
def generator(x, y, b_size):
... // 處理函數
model.fit_generator(generator(train_x, train_y, batch_size),
step_per_epochs=np.ceil(len(train_x)/batch_size),
epochs=5
)
1
2
3
4
5
6
7
從上述代碼中,我們發現有兩處不同:
一個我們自定義的generator()函數,作為fit_generator()函數的第一個參數;
fit_generator()函數的step_per_epochs參數
自定義的generator()函數
該函數即是我們數據的生成器,在訓練的時候,fit_generator()函數會不斷地執行generator()函數,獲取一個個的batch。
def generator(x, y, b_size):
"""Generates batch and batch and batch then feed into models.
Args:
x: input data;
y: input labels;
b_size: batch_size.
Yield:
(batch_x, batch_label): batched x and y.
"""
while 1: // 死循環
idx = ...
batch_x = ...
batch_y = ...
... // 任何你想要對這個`batch`中的數據執行的操作
yield (batch_x, batch_y)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
需要注意的是,不要使用return或者exit。
step_per_epochs參數
由於generator()函數的循環沒有終止條件,fit_generator也不知道一個epoch什么時候結束,所以我們需要手動指定step_per_epochs參數,一般的數值即為len(y)//batch_size。如果數據集大小不能整除batch_size,而且你打算使用最后一個batch的數據(該batch比batch_size要小),此時使用np.ceil(len(y)/batch_size)。
keras.utils.Sequence類(2019年6月10日更新)
除了寫generator()函數,我們還可以利用keras.utils.Sequence類來生成batch。先扔代碼:
class Generator(keras.utils.Sequence):
def __init__(self, x, y, b_size):
self.x, self.y = x, y
self.batch_size = b_size
def __len__(self):
return math.ceil(len(self.y)/self.batch_size
def __getitem__(self, idx):
b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
... // 對`batch`的其余操作
return np.array(b_x), np.array(b_y)
def on_epoch_end(self):
"""執行完一個`epoch`之后,還可以做一些其他的事情!"""
...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
我們首先定義__init__函數,讀取訓練集數據,然后定義__len__函數,返回一個epoch中需要執行的step數(此時在fit_generator()函數中就不需要指定steps_per_epoch參數了),最后定義__getitem__函數,返回一個batch的數據。代碼如下:
train_generator = Generator(train_x, train_y, batch_size)
val_generator = Generator(val_x, val_y, batch_size)
model.fit_generator(generator=train_generator,
epochs=3197747,
validation_data=val_generator
)
1
2
3
4
5
6
7
根據官方 [2] 的說法,使用Sequence類可以保證在多進程的情況下,每個epoch中的樣本只會被訓練一次。總之,使用keras.utils.Sequence也是很方便的啦!
train_on_batch()函數
train_on_batch()函數接受一個batch的輸入和標簽,然后開始反向傳播,更新參數等。大部分情況下你都不需要用到train_on_batch()函數,除非你有着充足的理由去定制化你的模型的訓練流程。
結語
本文到此結束啦!也不知道講清楚沒有,如果有疑問或者有錯誤,還請讀者不吝賜教啦!
Reference
A. Rosebrock. (December 24, 2018). How to use Keras fit and fit_generator (a hands-on tutorial). Retrieved from https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
tf.keras.utils.Sequence. (July 10, 2019). Retrieved from https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/utils/Sequence
---------------------