Keras study notes (2): data_generator


data_generator

Each call returns one batch; the generator is built on keras.utils.Sequence.

Base object for fitting to a sequence of data, such as a dataset.

Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs you may implement on_epoch_end. The method __getitem__ should return a complete batch.
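The contract above can be illustrated without Keras. The sketch below is a plain-Python stand-in: the class name ToySequence is hypothetical, and it does not subclass keras.utils.Sequence, so it runs without Keras installed. __len__ returns the number of batches, __getitem__ returns one complete (x, y) batch, and on_epoch_end reshuffles an index permutation between epochs. In real code you would subclass keras.utils.Sequence instead.

```python
import numpy as np


class ToySequence:
    """Hypothetical stand-in for a keras.utils.Sequence subclass (no Keras needed)."""

    def __init__(self, x, y, batch_size, shuffle=True, seed=0):
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(self.x))  # order in which samples are served

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Return one complete batch: the idx-th slice of the current index order.
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[batch_idx], self.y[batch_idx]

    def on_epoch_end(self):
        # Called between epochs; reshuffle so the next epoch sees a new order.
        if self.shuffle:
            self.rng.shuffle(self.indices)


seq = ToySequence(np.arange(10), np.arange(10) * 2, batch_size=4)
print(len(seq))  # 3
xb, yb = seq[2]
print(xb, yb)    # [8 9] [16 18]  (last, partial batch, before any shuffling)
seq.on_epoch_end()
```

Shuffling an index array rather than the data itself keeps on_epoch_end cheap, and __getitem__ still serves each sample once per epoch because the indices remain a permutation.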

Notes

Sequences are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch, which is not the case with generators.

   Sequence example: https://keras.io/utils/#sequence
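Why an indexed Sequence is safer: every batch is addressed by an integer index, so the batch indices of one epoch can be handed out to several workers without overlap, and together the batches cover every sample exactly once. The sketch below only simulates this in plain Python; the round-robin worker split in simulate_workers is a hypothetical illustration, not Keras's actual scheduling logic.

```python
import math


def batch_bounds(n_samples, batch_size):
    """(start, stop) bounds of each batch, equivalent to the slices taken in __getitem__."""
    n_batches = math.ceil(n_samples / batch_size)
    return [(i * batch_size, min((i + 1) * batch_size, n_samples)) for i in range(n_batches)]


def simulate_workers(n_samples, batch_size, n_workers):
    """Hypothetical round-robin assignment of batch indices to workers."""
    bounds = batch_bounds(n_samples, batch_size)
    served = []
    for worker in range(n_workers):
        # Worker k handles batches k, k + n_workers, k + 2 * n_workers, ...
        for batch_idx in range(worker, len(bounds), n_workers):
            start, stop = bounds[batch_idx]
            served.extend(range(start, stop))
    return served


served = simulate_workers(n_samples=10, batch_size=4, n_workers=2)
print(sorted(served))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

A plain Python generator has no such index: if it is copied into several worker processes, each copy walks the same data from the start, which is how duplicate samples can appear within an epoch.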

#!/usr/bin/env python
# coding: utf-8



from keras.utils import Sequence
import numpy as np
from keras.preprocessing import image
from skimage.io import imread

class My_Custom_Generator(Sequence):
    """Returns one batch per index: x of shape (batch, frames, H, W, C) and the labels."""

    def __init__(self, image_filenames, labels, batch_size):
        super().__init__()
        # image_filenames: list of sequences, each a list of frame image paths
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_y = self.labels[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_x = self.image_filenames[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_seq = []
        for x in batch_x:  # one sequence; len(batch_x) == batch_size, e.g. 16
            seq_img = []
            for img in x:  # one frame path; len(x) == frames per sequence, e.g. 25
                seq_img.append(image.img_to_array(imread(img)))
            batch_seq.append(seq_img)
        # Nested list -> array of shape (batch, frames, H, W, C); all frames must share one shape.
        batch_seq_list = np.array(batch_seq)
        return batch_seq_list, np.array(batch_y)
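My_Custom_Generator expects image_filenames to be a list whose elements are themselves lists of frame paths (one list per sequence), with labels in the same order. The helper below is a hypothetical way to build that input. It assumes one sub-folder per sequence, one image file per frame, and that sorting the file names gives the frame order. It only collects paths and reads no images.

```python
import os


def build_sequence_list(root, frames_per_seq, exts=(".jpg", ".jpeg", ".png")):
    """Hypothetical helper: one sub-folder of `root` per sequence -> list of frame-path lists.

    Assumes file names sort into frame order. Folders with fewer than
    `frames_per_seq` frames are skipped; longer ones are truncated.
    """
    sequences = []
    for name in sorted(os.listdir(root)):
        folder = os.path.join(root, name)
        if not os.path.isdir(folder):
            continue  # ignore stray files next to the sequence folders
        frames = sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.lower().endswith(exts)
        )
        if len(frames) >= frames_per_seq:
            sequences.append(frames[:frames_per_seq])
    return sequences
```

A call would then look like My_Custom_Generator(build_sequence_list(train_dir, 25), train_labels, 16), where train_dir and train_labels are placeholders. The labels must follow the same sorted folder order. Truncating every sequence to the same length keeps the later np.array call from producing a ragged array.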

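Two ways to assemble a batch as a numpy.array

Method 1: build a Python list, then convert it to numpy.array

Fast. When converting the list to an array, pay attention to how the dimensions change: np.array stacks the nested list into shape (batch, frames, H, W, C), and every frame must have the same shape.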

''' list
batch_x = X_train_filenames[idx * batch_size : (idx + 1) * batch_size]
batch_seq = []
for x in batch_x:  # one sequence per iteration; len(batch_x) == batch_size (16)
    seq_img = []
    for img in x:  # one frame path per iteration; len(x) == 25
        seq_img.append(image.img_to_array(imread(img)))
    batch_seq.append(seq_img)
batch_seq_list = np.array(batch_seq)  # shape (16, 25, H, W, C)
'''
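The dimension changes in the list method can be checked with dummy arrays instead of real images. The sketch below assumes 16 sequences of 25 frames, matching the comments in the code, but uses tiny 8x8x3 frames and np.zeros in place of imread, both of which are assumptions made only to keep it light. np.array on a list of equal-shaped frames adds a leading sequence axis, and wrapping a sequence in an extra list, as in np.array([seq_img]), adds another axis of size 1. np.stack gives the same result as np.array but always raises if a frame has the wrong shape.

```python
import numpy as np

BATCH, FRAMES, H, W, C = 16, 25, 8, 8, 3  # tiny H and W, for illustration only

# One sequence: a list of FRAMES arrays, each of shape (H, W, C).
seq_img = [np.zeros((H, W, C), dtype=np.float32) for _ in range(FRAMES)]
print(np.array(seq_img).shape)    # (25, 8, 8, 3)
print(np.array([seq_img]).shape)  # (1, 25, 8, 8, 3): the extra [] adds an axis

# A batch: a list of BATCH sequences becomes (batch, frames, H, W, C).
batch_seq = [seq_img for _ in range(BATCH)]
batch_arr = np.array(batch_seq)
print(batch_arr.shape)            # (16, 25, 8, 8, 3)

# np.stack builds the same array but always raises on a shape mismatch.
stacked = np.stack([np.stack(s) for s in batch_seq])
assert stacked.shape == batch_arr.shape

bad_seq = seq_img[:-1] + [np.zeros((H, W + 1, C), dtype=np.float32)]
try:
    np.stack(bad_seq)
except ValueError as err:
    print("shape mismatch:", err)
```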

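Method 2: use np.empty

Slow, because np.append allocates a new array and copies all existing data on every call. The upside is that you only need to fix the dimensions (here 25 frames of 224x224x3 per sequence) before starting.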

'''numpy
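batch_x = X_train_filenames[idx * batch_size : (idx + 1) * batch_size]
batch_seq = np.empty((0, 25, 224, 224, 3), float)  # (batch, frames, H, W, C), grown below
for x in batch_x:  # one sequence per iteration; len(batch_x) == batch_size (16)
    seq_batch = np.empty((0, 224, 224, 3), float)
    for item in x:  # one frame path per iteration; len(x) == 25
        frame = image.img_to_array(imread(item))
        seq_batch = np.append(seq_batch, np.expand_dims(frame, axis=0), axis=0)
    # np.append returns a new array, so the result must be assigned back to batch_seq
    batch_seq = np.append(batch_seq, np.expand_dims(seq_batch, axis=0), axis=0)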
'''
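Two points about the np.empty method can be shown with dummy data. First, np.append never modifies its argument. It returns a new, larger copy, so the result of each call must be assigned back to the array being grown, and every call copies all earlier data. Second, a different pattern, not the one used in this post, is to preallocate the whole batch once with np.empty and write each frame into its slot. This also needs the batch dimensions up front but avoids the repeated copying. The sketch assumes tiny 8x8x3 frames and uses a hypothetical fake_frame function in place of image.img_to_array(imread(...)).

```python
import numpy as np

BATCH, FRAMES, H, W, C = 4, 25, 8, 8, 3  # small sizes for illustration only


def fake_frame(seq_idx, frame_idx):
    """Hypothetical stand-in for image.img_to_array(imread(path))."""
    return np.full((H, W, C), seq_idx * 100 + frame_idx, dtype=np.float32)


# np.append returns a new array; its input is left unchanged.
empty = np.empty((0, H, W, C), dtype=np.float32)
grown = np.append(empty, fake_frame(0, 0)[np.newaxis], axis=0)
print(empty.shape, grown.shape)  # (0, 8, 8, 3) (1, 8, 8, 3)

# Growing with np.append, assigning each result back (copies on every call).
appended = np.empty((0, FRAMES, H, W, C), dtype=np.float32)
for s in range(BATCH):
    seq = np.empty((0, H, W, C), dtype=np.float32)
    for f in range(FRAMES):
        seq = np.append(seq, fake_frame(s, f)[np.newaxis], axis=0)
    appended = np.append(appended, seq[np.newaxis], axis=0)

# Preallocating once and filling in place (no repeated copying).
prealloc = np.empty((BATCH, FRAMES, H, W, C), dtype=np.float32)
for s in range(BATCH):
    for f in range(FRAMES):
        prealloc[s, f] = fake_frame(s, f)

assert np.array_equal(appended, prealloc)
```

Because np.empty leaves memory uninitialized, the preallocated version is only safe when every slot is written, which holds here because each sequence has exactly FRAMES frames.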
