實戰 遷移學習 VGG19、ResNet50、InceptionV3 實踐 貓狗大戰 問題


實戰 遷移學習 VGG19、ResNet50、InceptionV3 實踐 貓狗大戰 問題

一、實踐流程

1、數據預處理

主要是對訓練數據進行隨機偏移、轉動等變換圖像處理,這樣可以盡可能讓訓練數據多樣化

另外處理數據方式采用分批無序讀取的形式,避免了數據按目錄排序訓練

 

  1.  
    #數據准備
  2.  
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  3.  
    if is_train:
  4.  
    datagen = ImageDataGenerator(rescale= 1./255,
  5.  
    zoom_range= 0.25, rotation_range=15.,
  6.  
    channel_shift_range= 25., width_shift_range=0.02, height_shift_range=0.02,
  7.  
    horizontal_flip= True, fill_mode='constant')
  8.  
    else:
  9.  
    datagen = ImageDataGenerator(rescale= 1./255)
  10.  
     
  11.  
    generator = datagen.flow_from_directory(
  12.  
    dir_path, target_size=(img_row, img_col),
  13.  
    batch_size=batch_size,
  14.  
    shuffle=is_train)
  15.  
     
  16.  
    return generator
2、載入現有模型

 

這個部分是核心工作,目的是使用ImageNet訓練出的權重來做我們的特征提取器,注意這里后面的分類層去掉

 

  1.  
    base_model = InceptionV3(weights= 'imagenet', include_top=False, pooling=None,
  2.  
    input_shape=(img_rows, img_cols, color),
  3.  
    classes=nb_classes)

然后是凍結這些層,因為是訓練好的

 

  1.  
    for layer in base_model.layers:
  2.  
    layer.trainable = False
而分類部分,需要我們根據現有需求來新定義的,這里可以根據實際情況自己進行調整,比如這樣
  1.  
    x = base_model.output
  2.  
    # 添加自己的全鏈接分類層
  3.  
    x = GlobalAveragePooling2D()(x)
  4.  
    x = Dense( 1024, activation='relu')(x)
  5.  
    predictions = Dense(nb_classes, activation= 'softmax')(x)
或者

 

  1.  
    x = base_model.output
  2.  
    #添加自己的全鏈接分類層
  3.  
    x = Flatten()(x)
  4.  
    predictions = Dense(nb_classes, activation= 'softmax')(x)
3、訓練模型

這里我們用fit_generator函數,它可以避免了一次性加載大量的數據,並且生成器與模型將並行執行以提高效率。比如可以在CPU上進行實時的數據提升,同時在GPU上進行模型訓練

 

  1.  
    history_ft = model.fit_generator(
  2.  
    train_generator,
  3.  
    steps_per_epoch=steps_per_epoch,
  4.  
    epochs=epochs,
  5.  
    validation_data=validation_generator,
  6.  
    validation_steps=validation_steps)

二、貓狗大戰數據集

 

訓練數據540M,測試數據270M,大家可以去官網下載

https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

下載后把數據分成dog和cat兩個目錄來存放

三、訓練

訓練的時候會自動去下權值,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,但是如果我們已經下載好了的話,可以改源代碼,讓他直接讀取我們的下載好的權值,比如在resnet50.py中

1、VGG19

vgg19的深度有26層,參數達到了549M,原模型最后有3個全連接層做分類器所以我還是加了一個1024的全連接層,訓練10輪的情況達到了89%

2、ResNet50

ResNet50的深度達到了168層,但是參數只有99M,分類模型我就簡單點,一層直接分類,訓練10輪的達到了96%的准確率

3、inception_v3

InceptionV3的深度159層,參數92M,訓練10輪的結果

這是一層直接分類的結果

這是加了一個512全連接的,大家可以隨意調整測試

 

四、完整的代碼

 

  1.  
    # -*- coding: utf-8 -*-
  2.  
    import os
  3.  
    from keras.utils import plot_model
  4.  
    from keras.applications.resnet50 import ResNet50
  5.  
    from keras.applications.vgg19 import VGG19
  6.  
    from keras.applications.inception_v3 import InceptionV3
  7.  
    from keras.layers import Dense,Flatten,GlobalAveragePooling2D
  8.  
    from keras.models import Model,load_model
  9.  
    from keras.optimizers import SGD
  10.  
    from keras.preprocessing.image import ImageDataGenerator
  11.  
    import matplotlib.pyplot as plt
  12.  
     
  13.  
    class PowerTransferMode:
  14.  
    #數據准備
  15.  
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
  16.  
    if is_train:
  17.  
    datagen = ImageDataGenerator(rescale= 1./255,
  18.  
    zoom_range= 0.25, rotation_range=15.,
  19.  
    channel_shift_range= 25., width_shift_range=0.02, height_shift_range=0.02,
  20.  
    horizontal_flip= True, fill_mode='constant')
  21.  
    else:
  22.  
    datagen = ImageDataGenerator(rescale= 1./255)
  23.  
     
  24.  
    generator = datagen.flow_from_directory(
  25.  
    dir_path, target_size=(img_row, img_col),
  26.  
    batch_size=batch_size,
  27.  
    #class_mode='binary',
  28.  
    shuffle=is_train)
  29.  
     
  30.  
    return generator
  31.  
     
  32.  
    #ResNet模型
  33.  
    def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  34.  
    color = 3 if RGB else 1
  35.  
    base_model = ResNet50(weights= 'imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  36.  
    classes=nb_classes)
  37.  
     
  38.  
    #凍結base_model所有層,這樣就可以正確獲得bottleneck特征
  39.  
    for layer in base_model.layers:
  40.  
    layer.trainable = False
  41.  
     
  42.  
    x = base_model.output
  43.  
    #添加自己的全鏈接分類層
  44.  
    x = Flatten()(x)
  45.  
    #x = GlobalAveragePooling2D()(x)
  46.  
    #x = Dense(1024, activation='relu')(x)
  47.  
    predictions = Dense(nb_classes, activation= 'softmax')(x)
  48.  
     
  49.  
    #訓練模型
  50.  
    model = Model(inputs=base_model.input, outputs=predictions)
  51.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov= True)
  52.  
    model.compile(loss= 'categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  53.  
     
  54.  
    #繪制模型
  55.  
    if is_plot_model:
  56.  
    plot_model(model, to_file= 'resnet50_model.png',show_shapes=True)
  57.  
     
  58.  
    return model
  59.  
     
  60.  
     
  61.  
    #VGG模型
  62.  
    def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
  63.  
    color = 3 if RGB else 1
  64.  
    base_model = VGG19(weights= 'imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
  65.  
    classes=nb_classes)
  66.  
     
  67.  
    #凍結base_model所有層,這樣就可以正確獲得bottleneck特征
  68.  
    for layer in base_model.layers:
  69.  
    layer.trainable = False
  70.  
     
  71.  
    x = base_model.output
  72.  
    #添加自己的全鏈接分類層
  73.  
    x = GlobalAveragePooling2D()(x)
  74.  
    x = Dense( 1024, activation='relu')(x)
  75.  
    predictions = Dense(nb_classes, activation= 'softmax')(x)
  76.  
     
  77.  
    #訓練模型
  78.  
    model = Model(inputs=base_model.input, outputs=predictions)
  79.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov= True)
  80.  
    model.compile(loss= 'categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  81.  
     
  82.  
    # 繪圖
  83.  
    if is_plot_model:
  84.  
    plot_model(model, to_file= 'vgg19_model.png',show_shapes=True)
  85.  
     
  86.  
    return model
  87.  
     
  88.  
    # InceptionV3模型
  89.  
    def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
  90.  
    is_plot_model=False):
  91.  
    color = 3 if RGB else 1
  92.  
    base_model = InceptionV3(weights= 'imagenet', include_top=False, pooling=None,
  93.  
    input_shape=(img_rows, img_cols, color),
  94.  
    classes=nb_classes)
  95.  
     
  96.  
    # 凍結base_model所有層,這樣就可以正確獲得bottleneck特征
  97.  
    for layer in base_model.layers:
  98.  
    layer.trainable = False
  99.  
     
  100.  
    x = base_model.output
  101.  
    # 添加自己的全鏈接分類層
  102.  
    x = GlobalAveragePooling2D()(x)
  103.  
    x = Dense( 1024, activation='relu')(x)
  104.  
    predictions = Dense(nb_classes, activation= 'softmax')(x)
  105.  
     
  106.  
    # 訓練模型
  107.  
    model = Model(inputs=base_model.input, outputs=predictions)
  108.  
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov= True)
  109.  
    model.compile(loss= 'categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
  110.  
     
  111.  
    # 繪圖
  112.  
    if is_plot_model:
  113.  
    plot_model(model, to_file= 'inception_v3_model.png', show_shapes=True)
  114.  
     
  115.  
    return model
  116.  
     
  117.  
    #訓練模型
  118.  
    def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
  119.  
    # 載入模型
  120.  
    if is_load_model and os.path.exists(model_url):
  121.  
    model = load_model(model_url)
  122.  
     
  123.  
    history_ft = model.fit_generator(
  124.  
    train_generator,
  125.  
    steps_per_epoch=steps_per_epoch,
  126.  
    epochs=epochs,
  127.  
    validation_data=validation_generator,
  128.  
    validation_steps=validation_steps)
  129.  
    # 模型保存
  130.  
    model.save(model_url,overwrite= True)
  131.  
    return history_ft
  132.  
     
  133.  
    # 畫圖
  134.  
    def plot_training(self, history):
  135.  
    acc = history.history[ 'acc']
  136.  
    val_acc = history.history[ 'val_acc']
  137.  
    loss = history.history[ 'loss']
  138.  
    val_loss = history.history[ 'val_loss']
  139.  
    epochs = range(len(acc))
  140.  
    plt.plot(epochs, acc, 'b-')
  141.  
    plt.plot(epochs, val_acc, 'r')
  142.  
    plt.title( 'Training and validation accuracy')
  143.  
    plt.figure()
  144.  
    plt.plot(epochs, loss, 'b-')
  145.  
    plt.plot(epochs, val_loss, 'r-')
  146.  
    plt.title( 'Training and validation loss')
  147.  
    plt.show()
  148.  
     
  149.  
     
  150.  
    if __name__ == '__main__':
  151.  
    image_size = 197
  152.  
    batch_size = 32
  153.  
     
  154.  
    transfer = PowerTransferMode()
  155.  
     
  156.  
    #得到數據
  157.  
    train_generator = transfer.DataGen( 'data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
  158.  
    validation_generator = transfer.DataGen( 'data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)
  159.  
     
  160.  
    #VGG19
  161.  
    #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  162.  
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)
  163.  
     
  164.  
    #ResNet50
  165.  
    model = transfer.ResNet50_model(nb_classes= 2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
  166.  
    history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)
  167.  
     
  168.  
    #InceptionV3
  169.  
    #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
  170.  
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)
  171.  
     
  172.  
    # 訓練的acc_loss圖
  173.  
    transfer.plot_training(history_ft)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM