上一節中,我們利用了預訓練的VGG網絡卷積基,來簡單的提取了圖像的特征,並用這些特征作為輸入,訓練了一個小分類器。
這種方法好處在於簡單粗暴,特征提取部分的卷積基不需要訓練。但缺點在於,一是別人的模型是針對具體的任務訓練的,里面提取到的特征不一定適合自己的任務;二是無法使用圖像增強的方法進行端到端的訓練。
因此,更為常用的一種方法是預訓練模型修剪 + 微調,好處是可以根據自己任務需要,將預訓練的網絡和自定義網絡進行一定的融合;此外還可以使用圖像增強的方式進行端到端的訓練。仍然以VGG16為例,過程為:
- 在已經訓練好的基網絡(base network)上添加自定義網絡;
- 凍結基網絡,訓練自定義網絡;
- 解凍部分基網絡,聯合訓練解凍層和自定義網絡。
注意在聯合訓練解凍層和自定義網絡之前,通常要先訓練自定義網絡,否則,隨機初始化的自定義網絡權重會將大誤差信號傳到解凍層,破壞解凍層以前學到的表示,使得訓練成本增大。
第一步:對預訓練模型進行修改
##################第一步:在已經訓練好的卷積基上添加自定義網絡######################
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
#搭建模型
conv_base = VGG16(include_top=False, input_shape=(150,150,3)) #模型也可以看作一個層
model = Sequential()
model.add(conv_base)
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
#model.summary()
第二步:凍結卷積基,訓練自定義網絡
######################第二步:凍結卷積基,訓練自定義網絡##########################
#凍結卷積基,確保結果符合預期。或者用assert len(model.trainable_weights) == 30來驗證
print("凍結之前可訓練的張量個數:", len(model.trainable_weights)) #結果為30
conv_base.trainable = False
print("凍結之后可訓練的張量個數:", len(model.trainable_weights)) #結果為4
#注:只有后兩層Dense可以訓練,每層一個權重張量和一個偏置張量,所以有4個
#利用圖像生成器進行圖像增強
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1./255) #驗證、測試的圖像生成器不能用圖像增強
train_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\train'
validation_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\validation'
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=(150,150),
batch_size=20,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(validation_dir,
target_size=(150,150),
batch_size=20,
class_mode='binary')
#模型編譯和訓練,注意修改trainable屬性之后需要重新編譯,否則修改無效
from keras import optimizers
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
H = model.fit_generator(train_generator,
steps_per_epoch=2000/20,
epochs=30,
validation_data=validation_generator,
validation_steps=1000/20)
訓練30個epoch之后,結果如圖所示。(結果可視化代碼見上一節)

第三步:解凍部分卷積基(第5個block),聯合訓練
通常keras的凍結和解凍操作用的是模型或層的trainable屬性。需要注意三點:
- model.trainable是全局屬性,layer.trainable是層的屬性,單獨定義層的這一屬性后全局屬性即失效;
- 定義這一屬性后,模型需要重新編譯才能生效;
- conv_base是一個模型,但它在總模型model中是作為一個層的實例,因此遍歷model.layers時會把conv_base作為一個層,如果需要深入conv_base內部各層進行操作,需要遍歷conv_base.layers。
為了確保trainable屬性符合預期,通常會確認一下,下面一些代碼可能會有用。(這段主要是便於理解,跑代碼時可選擇性忽略這段。)
#可視化各層序號及名稱
for i, layer in enumerate(model.layers):
print(i, layer.name)
for i, layer in enumerate(conv_base.layers):
print(i, layer.name)
#由於之前操作錯誤,導致模型全部層都被凍結,所以這個模塊先把所有層解凍
for layer in conv_base.layers: #先解凍卷積基中所有層的張量
layer.trainable = True
for layer in model.layers: #解凍model中所有層張量
layer.trainable = True
#查看各層的trainable屬性
for layer in model.layers:
print(layer.name, layer.trainable)
for layer in conv_base.layers:
print(layer.name, layer.trainable)
#model.trainable = True #注意:設定單獨層的trainable屬性后,全局trainable屬性無效
print(len(conv_base.trainable_weights)) #26
print(len(model.trainable_weights)) #30
經過第二步之后,卷積基被凍結,后兩層Dense可訓練。接下來正式開始第三步,解凍第5個block,聯合訓練解凍層和自定義網絡。
######################第三步:解凍部分卷積基,聯合訓練##########################
#凍結VGG16中前四個block,解凍第五個block
flag = False #標記是否到達第五個block
for layer in conv_base.layers: #注意不是遍歷model.layers
if layer.name == 'block5_conv1': #若到達第五個block,則標記之
flag = True
if flag == False: #若標記為False,則凍結,否則設置為可訓練
layer.trainable = False
else:
layer.trainable = True
print(len(model.trainable_weights)) #應為10
#重新編譯並訓練。血淚教訓,一定要重新編譯,不然trainable屬性就白忙活了!
from keras import optimizers
#注:吐血,官網文檔參數learning_rate,這里竟然不認,只能用lr
model.compile(loss='binary_crossentropy',
optimizer=optimizers.Adam(lr=1e-5), metrics=['accuracy'])
H2 = model.fit_generator(train_generator,
steps_per_epoch=2000/20,
epochs=100,
validation_data=validation_generator,
validation_steps=1000/20)
經過100個epoch之后,結果如下。可以看出驗證准確率被提高到94%左右。

Reference:
書籍:Python深度學習
