Keras 使用自己編寫的數據生成器


使用自己編寫的數據生成器,配合keras的fit_generator訓練模型

注意:模型結構要和生成器生成數據的尺寸要對應,txt存的數據路徑一般是有序的,想辦法打亂它

# 以下部分代碼,僅做示意
……
def gen_mine():
    txtpath = './2.txt' # 數據路徑存在txt
    data_train = []
    data_labels = []
    cnt = 0 # 用於批量計數
    for n in open(txtpath):
        img = cv2.imread(n[:-1]) # 最后一個字節是換行符,去掉它
        img_64 = cv2.resize(img,(64,64)) # 輸入到模型前要統一尺寸
        img_rgb = img_64[:,:,::-1] # cv讀的數據是bgr,這里改成標准的rgb
        if n.split('/')[1] == 'file_N': # 由於我是根據文件夾的名字定的標簽,這個看自己的需求
            label = [0,1,0] # 注意要寫成獨熱編碼的形式
        else:
            label = [1,0,0]
        data_train.append(img_rgb)
        data_labels.append(label)
        cnt = cnt + 1
        if cnt == BS:
            cnt = 0 # 初始化
            data_train = np.array(data_train)
            data_labels = np.array(data_labels)
            print(data_train.shape, data_labels.shape)
            yield (data_train, data_labels)
            data_train = [] # 初始化
            data_labels = []
……
model.fit_generator(gen_mine(),steps_per_epoch=steps_per_epoch_, epochs=NUM_EPOCHS, class_weight = 'auto', max_queue_size=1,workers=1)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM