最近研究了下如何使用tensorflow進行finetuning,相比於caffe,tensorflow的finetuning麻煩一些,記錄如下:
1.原理
finetuning原理很簡單,利用一個在數據A集上已訓練好的模型作為初始值,改變其部分結構,在另一數據集B上(采用小學習率)訓練的過程叫做finetuning。
一般來講,符合如下情況會采用finetuning
- 數據集A和B有相關性
- 數據集A較大
- 數據集B較小
2.關鍵代碼
在數據集A上訓練的時候,和普通的tensorflow訓練過程完全一致。但是在數據集B上進行finetuning時,需要先從之前訓練好的checkpoint中恢復模型參數,這個地方比較關鍵,
需要注意只恢復需要恢復的參數,其他參數不要恢復,否則會因為找不到的聲明而報錯。以mnist為例子,如果我想先訓練一個0-7的8類分類器,網絡結構如下:
conv1-conv2-fc8(其他不帶權重的pooling、softmaxloss層忽略)
然后我想用這個訓練出的模型參數,在0-9的10類分類器上做finetuning,網絡結構如下:
conv1-conv2-fc10
那么在從checkpoint中恢復模型參數時,我只能恢復conv1-conv2,如果連fc8都恢復了,就會因為找不到fc8的定義而報錯
以上描述對應的代碼如下:
1 if tf.train.latest_checkpoint('ckpts') is not None: 2 trainable_vars = tf.trainable_variables() 3 res_vars = [t for t in trainable_vars if t.name.startswith('conv')] 4 saver = tf.train.Saver(var_list=res_vars) 5 saver.restore(sess, tf.train.latest_checkpoint('ckpts')) 6 else: 7 saver = tf.train.Saver()
3.demo
利用mnist寫了一個簡單的finetuning例子,大家可以試試,事實證明,利用一個相關的已有模型做finuetuning比從0開始訓練收斂的更快並且收斂到的准確率更高,