Copyright notice: This is an original article by the author. Reposting is welcome; please credit the source. Contact: 460356155@qq.com
1. Downloading the dataset
Searching Baidu for "kaggle cats vs. dogs dataset" (kaggle 貓狗數據集) turns up copies of the dataset shared via network drives; the download is about 815 MB.
2. Preparing the dataset
The full dataset contains 25,000 images, 12,500 each of cats and dogs. From these, 1,000, 500, and 200 images per class are selected as the training, validation, and test sets, respectively.
import os
import random
import shutil


# Randomly draw a sample subset
def get_sub_sample(sample_path, target_path, train, valid, test,
                   file_name_format=None, class_name=None, class_num=None):
    """
    sample_path: directory holding the full sample set
    target_path: directory for the sample subset
    train, valid, test: number of training, validation and test samples to draw
    file_name_format: file-name pattern used to filter one class
    class_name: name of the sample class
    class_num: number of classes
    """
    # All files directly under sample_path; subdirectories are not traversed
    all_files = [f for f in os.listdir(sample_path)
                 if os.path.isfile(os.path.join(sample_path, f))]
    total = len(all_files)

    if file_name_format:
        # Handle the case where one directory mixes several classes
        num_per_class = int(total / class_num)
        fnames = [file_name_format.format(i) for i in range(num_per_class)]
    else:
        fnames = all_files

    # Shuffle the order
    random.shuffle(fnames)

    os.makedirs(os.path.join(target_path, 'train', class_name))
    os.makedirs(os.path.join(target_path, 'valid', class_name))
    os.makedirs(os.path.join(target_path, 'test', class_name))

    for i in range(train):
        src = os.path.join(sample_path, fnames[i])
        dst = os.path.join(target_path, 'train', class_name, fnames[i])
        shutil.copyfile(src, dst)

    for i in range(train, train + valid):
        src = os.path.join(sample_path, fnames[i])
        dst = os.path.join(target_path, 'valid', class_name, fnames[i])
        shutil.copyfile(src, dst)

    for i in range(train + valid, train + valid + test):
        src = os.path.join(sample_path, fnames[i])
        dst = os.path.join(target_path, 'test', class_name, fnames[i])
        shutil.copyfile(src, dst)


src_path = r'D:\BaiduNetdiskDownload\train'
dst_path = r'D:\BaiduNetdiskDownload\small'

train_dir = os.path.join(dst_path, 'train')
validation_dir = os.path.join(dst_path, 'valid')

class_name = ['cat', 'dog']

if os.path.exists(dst_path):
    shutil.rmtree(dst_path)
os.makedirs(dst_path)

for cls in class_name:
    get_sub_sample(src_path, dst_path, 1000, 500, 200,
                   file_name_format='%s.{}.jpg' % cls,
                   class_name=cls, class_num=2)
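As a quick sanity check (reusing the dst_path and class_name defined in the script above), the number of images copied into each split can be printed; it should come out to 1000/500/200 per class:

# Count the files copied into each split directory
for split in ('train', 'valid', 'test'):
    for cls in class_name:
        split_dir = os.path.join(dst_path, split, cls)
        print(split, cls, len(os.listdir(split_dir)))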
3. Building the model
from keras import layers
from keras import models

model = models.Sequential()
# Output feature map: 150 - 3 + 1 = 148*148, parameters: 32*3*3*3 + 32 = 896
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))  # output feature map: 148/2 = 74*74
# Output feature map: 74 - 3 + 1 = 72*72, parameters: 64*3*3*32 + 64 = 18496
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))  # output feature map: 72/2 = 36*36
# Output feature map: 36 - 3 + 1 = 34*34, parameters: 128*3*3*64 + 128 = 73856
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))  # output feature map: 34/2 = 17*17
# Output feature map: 17 - 3 + 1 = 15*15, parameters: 128*3*3*128 + 128 = 147584
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))  # output feature map: 15/2 = 7*7
# Flatten to one dimension: 7*7*128 = 6272
model.add(layers.Flatten())
# Parameters: 6272*512 + 512 = 3211776
model.add(layers.Dense(512, activation='relu'))
# Parameters: 512*1 + 1 = 513
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
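The parameter counts written in the comments above follow the usual Conv2D formula, out_channels * (k * k * in_channels) + out_channels. A small check (conv_params is just an illustrative helper, not part of Keras):

def conv_params(in_channels, out_channels, k=3):
    # k*k*in_channels weights per filter, plus one bias per filter
    return out_channels * (k * k * in_channels) + out_channels

print(conv_params(3, 32))     # 896
print(conv_params(32, 64))    # 18496
print(conv_params(64, 128))   # 73856
print(conv_params(128, 128))  # 147584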
4. Compiling the model
from keras import optimizers

# binary_crossentropy for a two-class problem
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])
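Note that newer Keras / tf.keras releases rename the lr argument to learning_rate; under that assumption, the equivalent compile call would be:

# Same configuration for tf.keras 2.x, which uses learning_rate instead of lr
from tensorflow.keras import optimizers

model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=1e-4),
              metrics=['acc'])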
5. Creating the training and validation data
from keras.preprocessing.image import ImageDataGenerator

# Rescale pixel values to [0, 1]
train_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    # size the input images are resized to
    target_size=(150, 150),
    batch_size=20,
    # binary labels for a two-class problem
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary')
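To confirm that the generators yield what the network expects, one batch can be inspected before training (a quick check using the train_generator defined above):

for data_batch, labels_batch in train_generator:
    # Each batch holds 20 RGB images resized to 150x150, plus 20 binary labels
    print('data batch shape:', data_batch.shape)      # (20, 150, 150, 3)
    print('labels batch shape:', labels_batch.shape)  # (20,)
    break  # the generator loops forever, so stop after one batch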
6. Training
history = model.fit_generator(
    train_generator,
    # 2000 training images / batch size of 20
    steps_per_epoch=100,
    epochs=30,
    validation_data=validation_generator,
    # 1000 validation images / batch size of 20
    validation_steps=50)
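fit_generator is deprecated in later tf.keras releases; assuming tf.keras 2.1 or newer, the same run could be launched with model.fit, which accepts generators directly (a sketch only, not what produced the log below):

# Equivalent call for newer tf.keras, where model.fit accepts Python generators
history = model.fit(
    train_generator,
    steps_per_epoch=100,
    epochs=30,
    validation_data=validation_generator,
    validation_steps=50)

The original fit_generator call produced the following log: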
WARNING:tensorflow:From d:\program files\python37\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Epoch 1/30
100/100 [==============================] - 41s 409ms/step - loss: 0.6903 - acc: 0.5255 - val_loss: 0.6730 - val_acc: 0.6070
Epoch 2/30
100/100 [==============================] - 41s 406ms/step - loss: 0.6599 - acc: 0.6070 - val_loss: 0.6350 - val_acc: 0.6510
Epoch 3/30
100/100 [==============================] - 41s 408ms/step - loss: 0.6135 - acc: 0.6710 - val_loss: 0.6223 - val_acc: 0.6400
Epoch 4/30
100/100 [==============================] - 41s 410ms/step - loss: 0.5816 - acc: 0.6960 - val_loss: 0.5798 - val_acc: 0.6950
Epoch 5/30
100/100 [==============================] - 41s 411ms/step - loss: 0.5582 - acc: 0.7160 - val_loss: 0.5757 - val_acc: 0.6970
Epoch 6/30
100/100 [==============================] - 42s 420ms/step - loss: 0.5278 - acc: 0.7360 - val_loss: 0.5788 - val_acc: 0.6790
Epoch 7/30
100/100 [==============================] - 41s 412ms/step - loss: 0.5096 - acc: 0.7485 - val_loss: 0.5551 - val_acc: 0.7140
Epoch 8/30
100/100 [==============================] - 42s 418ms/step - loss: 0.4809 - acc: 0.7715 - val_loss: 0.5871 - val_acc: 0.6870
Epoch 9/30
100/100 [==============================] - 42s 416ms/step - loss: 0.4645 - acc: 0.7850 - val_loss: 0.5309 - val_acc: 0.7370
Epoch 10/30
100/100 [==============================] - 42s 415ms/step - loss: 0.4348 - acc: 0.7960 - val_loss: 0.5618 - val_acc: 0.7160
Epoch 11/30
100/100 [==============================] - 42s 420ms/step - loss: 0.4133 - acc: 0.8050 - val_loss: 0.5714 - val_acc: 0.7210
Epoch 12/30
100/100 [==============================] - 41s 409ms/step - loss: 0.3847 - acc: 0.8215 - val_loss: 0.5937 - val_acc: 0.7030
Epoch 13/30
100/100 [==============================] - 41s 413ms/step - loss: 0.3523 - acc: 0.8465 - val_loss: 0.6225 - val_acc: 0.7030
Epoch 14/30
100/100 [==============================] - 42s 416ms/step - loss: 0.3339 - acc: 0.8535 - val_loss: 0.5339 - val_acc: 0.7500
Epoch 15/30
100/100 [==============================] - 43s 428ms/step - loss: 0.3013 - acc: 0.8650 - val_loss: 0.5404 - val_acc: 0.7520
Epoch 16/30
100/100 [==============================] - 42s 417ms/step - loss: 0.2736 - acc: 0.8885 - val_loss: 0.5885 - val_acc: 0.7380
Epoch 17/30
100/100 [==============================] - 41s 415ms/step - loss: 0.2562 - acc: 0.8995 - val_loss: 0.5636 - val_acc: 0.7420
Epoch 18/30
100/100 [==============================] - 41s 415ms/step - loss: 0.2294 - acc: 0.9115 - val_loss: 0.5722 - val_acc: 0.7490
Epoch 19/30
100/100 [==============================] - 42s 415ms/step - loss: 0.2004 - acc: 0.9210 - val_loss: 0.6201 - val_acc: 0.7390
Epoch 20/30
100/100 [==============================] - 41s 413ms/step - loss: 0.1812 - acc: 0.9315 - val_loss: 0.6323 - val_acc: 0.7390
Epoch 21/30
100/100 [==============================] - 42s 423ms/step - loss: 0.1551 - acc: 0.9495 - val_loss: 0.5949 - val_acc: 0.7530
Epoch 22/30
100/100 [==============================] - 50s 500ms/step - loss: 0.1438 - acc: 0.9505 - val_loss: 0.6145 - val_acc: 0.7500
Epoch 23/30
100/100 [==============================] - 45s 447ms/step - loss: 0.1131 - acc: 0.9660 - val_loss: 0.7587 - val_acc: 0.7340
Epoch 24/30
100/100 [==============================] - 42s 415ms/step - loss: 0.1012 - acc: 0.9650 - val_loss: 0.7000 - val_acc: 0.7500
Epoch 25/30
100/100 [==============================] - 42s 425ms/step - loss: 0.0852 - acc: 0.9765 - val_loss: 0.7501 - val_acc: 0.7400
Epoch 26/30
100/100 [==============================] - 43s 427ms/step - loss: 0.0730 - acc: 0.9785 - val_loss: 0.7945 - val_acc: 0.7500
Epoch 27/30
100/100 [==============================] - 41s 410ms/step - loss: 0.0643 - acc: 0.9825 - val_loss: 0.7769 - val_acc: 0.7480
Epoch 28/30
100/100 [==============================] - 41s 415ms/step - loss: 0.0544 - acc: 0.9860 - val_loss: 0.8410 - val_acc: 0.7530
Epoch 29/30
100/100 [==============================] - 41s 410ms/step - loss: 0.0435 - acc: 0.9910 - val_loss: 0.8678 - val_acc: 0.7670
Epoch 30/30
100/100 [==============================] - 41s 411ms/step - loss: 0.0370 - acc: 0.9920 - val_loss: 0.8941 - val_acc: 0.7640
The validation loss reaches its minimum at the 9th epoch, with validation accuracy around 74%; as the number of epochs increases, the model overfits. Plot the training curves:
%matplotlib inline
import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
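The epoch with the lowest validation loss (epoch 9 in the run above) can also be read straight from the lists extracted from history, for example:

# 0-based index of the smallest validation loss, reported as a 1-based epoch number
best_epoch = val_loss.index(min(val_loss)) + 1
print('best epoch:', best_epoch)
print('val_loss:', min(val_loss), 'val_acc:', val_acc[best_epoch - 1])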
7. Saving the model
model.save('cats_and_dogs_small_1.h5')
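The saved HDF5 file can later be reloaded with keras.models.load_model for evaluation or prediction, for example:

from keras.models import load_model

# Restore the trained network from the file saved above
model = load_model('cats_and_dogs_small_1.h5')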