tensorflow finuetuning 例子


最近研究了下如何使用tensorflow進行finetuning,相比於caffe,tensorflow的finetuning麻煩一些,記錄如下:

1.原理

finetuning原理很簡單,利用一個在數據A集上已訓練好的模型作為初始值,改變其部分結構,在另一數據集B上(采用小學習率)訓練的過程叫做finetuning。

一般來講,符合如下情況會采用finetuning

  • 數據集A和B有相關性
  • 數據集A較大
  • 數據集B較小

 

2.關鍵代碼

在數據集A上訓練的時候,和普通的tensorflow訓練過程完全一致。但是在數據集B上進行finetuning時,需要先從之前訓練好的checkpoint中恢復模型參數,這個地方比較關鍵,

需要注意只恢復需要恢復的參數,其他參數不要恢復,否則會因為找不到的聲明而報錯。以mnist為例子,如果我想先訓練一個0-7的8類分類器,網絡結構如下:

conv1-conv2-fc8(其他不帶權重的pooling、softmaxloss層忽略)

然后我想用這個訓練出的模型參數,在0-9的10類分類器上做finetuning,網絡結構如下:

conv1-conv2-fc10

那么在從checkpoint中恢復模型參數時,我只能恢復conv1-conv2,如果連fc8都恢復了,就會因為找不到fc8的定義而報錯

以上描述對應的代碼如下:

1     if tf.train.latest_checkpoint('ckpts') is not None:
2         trainable_vars = tf.trainable_variables()
3         res_vars = [t for t in trainable_vars if t.name.startswith('conv')]
4         saver = tf.train.Saver(var_list=res_vars)
5         saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
6     else:
7         saver = tf.train.Saver()

 

3.demo

利用mnist寫了一個簡單的finetuning例子,大家可以試試,事實證明,利用一個相關的已有模型做finuetuning比從0開始訓練收斂的更快並且收斂到的准確率更高,

點我下載

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM