keras fit_generator 並行



雖然已經走在 torch boy 的路上了, 還是把碰到的這個坑給記錄一下

  • 數據量較小時,我們可直接把整個數據集 load 到內存里,用 model.fit() 來擬合模型。
  • 當數據集過大比如幾十個 G 時,內存撐不下,需要用 model.fit_generator 的方式來擬合。

model.fit_generator 一般參數的配置參考官方文檔就好,其中 generator, workers, use_multiprocessing 的使用有一些坑存在。

workers=0, use_multiprocessing=False

此時 generator 用一個普通的 generator去提供數據即可,類似官方提供的這種

def generate_arrays_from_file(path):
    while True:
        with open(path) as f:
            for line in f:
                # create numpy arrays of input data
                # and labels, from each line in the file
                x1, x2, y = process_line(line)
                yield ({'input_1': x1, 'input_2': x2}, {'output': y})

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                    steps_per_epoch=10000, epochs=10)

workers>0, use_multiprocessing=True

這時依然用一個 generator function 來做 generator在擬合的時候便會報錯如下:

PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at

且當 use_multiprocessing=True 時,如果你使用的是 generator function, 代碼會把你的數據copy幾份分給不同的worker去處理,但我們希望的是把一份數據平均分拆成幾份給多個worker去處理。

怎么解決上面兩個問題? keras.utils.Sequence 可以做到

很簡單,繼承 keras.utils.Sequence 這個類,重寫自己的 len(), getitem 即可。

class SequenceData(Sequence):
    def __init__(self, filePaths, batch_size):
        self.filePaths = filePaths[:100].copy()
        self.batch_size = batch_size
        self.Y = self.getY()

    def __len__(self):
        return len(self.Y) // self.batch_size

    def __getitem__(self, index):
        batch_X = np.zeros((self.batch_size,) + IMG_DIMS, dtype='float32')
        batch_Y_ = self.Y[index*self.batch_size: (index+1)*self.batch_size].copy()
        batch_Y_.reset_index(drop=True, inplace=True)
        assert batch_Y_.shape[0] == self.batch_size

        for index, rows in batch_Y_.iterrows():
            try:
                img = _load_img(rows['path'])
                batch_X[index, :, :, :] = img.copy()
                batch_Y_.loc[index, 'valid'] = 1
            except:
                batch_Y_.loc[index, 'valid'] = 0
                traceback.print_exc()
        batch_Y = to_categorical(batch_Y_['label'], classes_num)
        return batch_X, batch_Y

    def __iter__(self):
        for item in (self[i] for i in range(len(self))):
            yield item

    def getY(self):
        Y = pd.DataFrame(self.filePaths, columns=['path'])
        Y['class'] = Y['path'].apply(lambda x: path2class(x))
        Y['label'] = Y['class'].apply(lambda x: class2label[x])
        Y = Y.sample(frac=1).reset_index(drop=True)
        return Y

效果比較

  • 樣本量:1000張圖片
  • 模型: MobileNetV2
  • epochs: 5
  • CPU: 4核,3.4GHz
  • GPU: None

可能數據量過小,並行的效果不是太明顯。

數據讀取方式 workers use_multiprocessing 耗時/s
內存讀取 0 True 1797
keras.utils.Sequence 0 False 1475
keras.utils.Sequence 4 True

參考:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM