將圖片集使用迭代器,分batch輸入到卷積神經網絡中


問題來源:寫了一個神經網絡,需要用的測試集是本地圖片。

第一次嘗試解決:將本地圖片讀取,亂序,存成npz形式的文件。在第二次使用時,load這個npz文件。但這個方法針對圖片量比較大的情況沒辦法應對,圖片大小超過電腦內存。

第二次嘗試解決:嘗試將文件分批存儲成npz形式,一次讀取數據進行訓練,但是在keras平台下難以訓練。

第三次嘗試解決:采用迭代器分批讀取數據,使用fit_generator分批訓練網絡。這是需要注意的問題是,如果連續讀入的都是一類文件,就會導致模型偏移。因為前半部分都是數據1,后半部分都是數據2.

解決思路如下:定義一個迭代器,在迭代器中,依次循環從數據文件夾中讀取數據文件,這樣可以保證一個batch中的每種類別的數據的個數是相同的。

  1 import os
  2 import cv2
  3 import keras
  4 import numpy as np
  5 def Generator(path, batch_size, data_num):
  6     i = 1
  7     data = []
  8     label = []
  9     while True:
 10         i = 1
 11         while i <= data_num:
 12             for j in range(1,3):
 13                 f = os.path.join('dataset', path, '%d' %j, '%d.jpg' %i)
 14                 im = cv2.imread(f)
 15                 im = cv2.resize(im, (227, 227))
 16                 data.append(im)
 17                 label.append(j-1)
 18             if(len(label) == batch_size):
 19                 data = np.array(data, dtype='float32')
 20                 label = keras.utils.to_categorical(label, 2)
 21                 yield data, label
 22                 data = []
 23                 label = []
 24             i += 1
 25 
 26 #*******************AlexNet_begin**************************
 27 from keras.models import Sequential
 28 from keras.layers.convolutional import Conv2D, MaxPooling2D
 29 from keras.layers import Dense, Flatten, Dropout, BatchNormalization
 30 
 31 batch_size = 4
 32 num_classes = 2
 33 epochs = 8
 34 
 35 model = Sequential()
 36 #第一層
 37 model.add(Conv2D(96, (11, 11),      #卷積核
 38                  strides=4,         #步長
 39                  input_shape=(227, 227, 3),
 40                  padding='valid',   #無填充
 41                  activation='relu'))
 42 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 43 model.add(BatchNormalization())
 44 #第二層
 45 model.add(Conv2D(
 46     kernel_size=(27, 27),
 47     filters= 256,
 48     strides= 1,
 49     padding='same',
 50     activation='relu'
 51 ))
 52 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 53 #第三層
 54 model.add(Conv2D(
 55     kernel_size=(3, 3),
 56     filters= 384,
 57     strides= 1,
 58     padding= 'same',
 59     activation= 'relu'
 60 ))
 61 #第四層
 62 model.add(Conv2D(
 63     kernel_size= (3, 3),
 64     filters= 384,
 65     strides= 1,
 66     padding= 'same',
 67     activation= 'relu'
 68 ))
 69 #第五層
 70 model.add(Conv2D(
 71     kernel_size= (3, 3),
 72     filters= 256,
 73     strides= 1,
 74     padding= 'same',
 75     activation= 'relu'
 76 ))
 77 model.add(MaxPooling2D(pool_size=(3, 3), strides=2 ))
 78 #第六層
 79 model.add(Flatten())
 80 model.add(Dense(128, activation='relu'))
 81 model.add(Dropout(0.5))
 82 #第七層
 83 model.add(Dense(128, activation='relu'))
 84 model.add(Dropout(0.5))
 85 #第八層
 86 model.add(Dense(2, activation='softmax'))
 87 #打印模型
 88 model.summary()
 89 
 90 model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])
 91 #model.fit(x_train, y_train, epochs=5, validation_data=[x_test, y_test])
 92 print('fit......')
 93 model.fit_generator(Generator('train', 4, 100),
 94                         epochs=epochs,
 95                         steps_per_epoch=50,
 96                         workers=1)
 97 model.save('model.h5')
 98 # Evaluate the model with the metrics we defined earlier
 99 print('evaluate......')
100 loss1, accuracy1 = model.evaluate_generator(Generator('test', 4, 20), steps = 10)
101 print('test......')
102 loss2, accuracy2 = model.evaluate_generator(Generator('test2', 4, 12), steps = 6)
103 print('test1 loss: ', loss1)
104 print('test1 accuracy: ', accuracy1)
105 print('test2 loss: ', loss2)
106 print('test2 accuracy: ', accuracy2)
107 #*******************AlexNet_end****************************

數據文件結構如圖所示:

 

參考資料鏈接:

1、https://github.com/keras-team/keras/issues/7729

2、https://blog.csdn.net/learning_tortosie/article/details/85243310

3、https://keras-cn.readthedocs.io/en/latest/models/model/

4、https://blog.csdn.net/shahuzi/article/details/81210557

5、https://blog.csdn.net/yideqianfenzhiyi/article/details/79197570

6、https://zhuanlan.zhihu.com/p/31558973


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM