keras中訓練數據的幾種方式對比(fit和fit_generator)


一、train_on_batch

model.train_on_batch(batchX, batchY)

train_on_batch函數接受單批數據,執行反向傳播,然后更新模型參數,該批數據的大小可以是任意的,即,它不需要提供明確的批量大小,屬於精細化控制訓練模型,大部分情況下我們不需要這么精細,99%情況下使用fit_generator訓練方式即可,下面會介紹。

二、fit

model.fit(x_train, y_train, batch_size=32, epochs=10)

fit的方式是一次把訓練數據全部加載到內存中,然后每次批處理batch_size個數據來更新模型參數,epochs就不用多介紹了。這種訓練方式只適合訓練數據量比較小的情況下使用。

三、fit_generator

利用Python的生成器,逐個生成數據的batch並進行訓練,不占用大量內存,同時生成器與模型將並行執行以提高效率。例如,該函數允許我們在CPU上進行實時的數據提升,同時在GPU上進行模型訓練

接口如下:

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

  • generator:生成器函數
  • steps_per_epoch:整數,當生成器返回steps_per_epoch次數據時,計一個epoch結束,執行下一個epoch。也就是一個epoch下執行多少次batch_size。
  • epochs:整數,控制數據迭代的輪數,到了就結束訓練。
  •  callbacks=None,  list,list中的元素為keras.callbacks.Callback對象,在訓練過程中會調用list中的回調函數

舉例:

def generate_arrays_from_file(path):
            while True:
                with open(path) as f:
                    for line in f:
                        # create numpy arrays of input data
                        # and labels, from each line in the file
                        x1, x2, y = process_line(line)
                        yield ({'input_1': x1, 'input_2': x2}, {'output': y})
 
model.fit_generator(generate_arrays_from_file('./my_folder'),
                            steps_per_epoch=10000, epochs=10)

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM