簡介
caffe是一個友好、易於上手的開源深度學習平台,主要用於圖像的相關處理,可以支持CNN等多種深度學習網絡。
基於caffe,開發者可以方便快速地開發簡單的學習網絡,用於分類、定位等任務,也可以用於科研,在其源碼基礎上進行修改,實現自己的算法。
本文的主要目的,是介紹caffe的基本使用方法,希望通過本文,能讓普通的工程師可以使用caffe訓練自己的簡單模型。
本文主要包括以下內容:運行caffe的例子訓練cifar訓練集、使用別人定義好的網絡訓練自己的數據、使用訓練好的模型fine tune自己的數據。
背景知識簡介
深度學習是機器學習的一個分支,主要目標在於通過學習的方法,解決以往普通編程無法解決的問題,例如:圖像識別、文字識別等等。
機器學習里的“學習”,指通過向程序輸入經驗數據,通過若干次“迭代”,不斷改進算法參數,最終能夠獲得“模型”,使用新數據輸入模型,計算得出想要的結果。
例如圖像分類任務中,經驗數據是圖片和對應的文字,訓練出模型后,將新圖片使用模型運算,就可以知道其對應的類別。
以上只是簡單介紹,這里還是建議先學習機器學習、卷積神經網絡的相關基礎知識。
安裝
這一部分網上有不少教程,這里就略掉,另外,我是用docker的鏡像直接安裝的,網上可以直接搜到帶caffe的docker鏡像。好處是省去安裝環境的時間,缺點是后面設置文件會麻煩一些,建議從長計議還是直接安裝在電腦上。
訓練cifar訓練集
cifar是一個常見的圖像分類訓練集,包括上萬張圖片及20個分類,caffe提供了一個網絡用於分類cifar數據集。
cifar網絡的定義在examples/cifar10目錄下,訓練的過程十分簡單。
(以下命令均在caffe默認根目錄下運行,下同)
1、獲取訓練數據
cd $CAFFE_ROOT ./data/cifar10/get_cifar10.sh ./examples/cifar10/create_cifar10.sh
2、開始訓練
cd $CAFFE_ROOT ./examples/cifar10/train_quick.sh
3、訓練完成后我們會得到:
cifar10_quick_iter_4000.caffemodel.h5
cifar10_quick_iter_4000.solverstate.h5
此時,我們就訓練得到了模型,用於后面的分類。
4、下面我們使用模型來分類新數據
先直接用一下別人的模型分類試一下:(默認用的ImageNet的模型)
python python/classify.py examples/images/cat.jpg foo
下面我們來指定自己的模型進行分類:
python python/classify.py --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --center_only examples/images/cat.jpg foo
上面這句話的意思是,使用cifar10_quick.prototxt網絡 + cifar10_quick_iter_4000.caffemodel.h5模型,對examples/images/cat.jpg圖片進行分類。
默認的classify腳本不會直接輸出結果,而是會把結果輸入到foo文件里,不太直觀,這里我在網上找了一個修改版,添加了一些參數,可以輸出概率最高的分類。
替換python/classify.py,下載地址:http://download.csdn.net/detail/caisenchuan/9513196
這個腳本添加了兩個參數,可以指定labels_file,然后可以直接把分類結果輸出出來:
python python/classify.py --print_results --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --labels_file data/cifar10/cifar10_words.txt --center_only examples/images/cat.jpg foo
輸出結果:
Loading file: examples/images/cat.jpg Classifying 1 inputs. predict 3 inputs. Done in 0.02 s. Predictions : [[ 0.03903743 0.00722749 0.04582177 0.44352672 0.01203315 0.11832549 0.02335102 0.25013766 0.03541689 0.02512246]] python/classify.py:176: FutureWarning: sort(columns=....) is deprecated, use sort_values(by=.....) labels = labels_df.sort('synset_id')['name'].values [('cat', '0.44353'), ('horse', '0.25014'), ('dog', '0.11833'), ('bird', '0.04582'), ('airplane', '0.03904')] 上面標明了各個分類的順序和置信度
Saving results into foo
Tips
最后,總結一下訓練一個網絡用到的相關文件:
cifar10_quick_solver.prototxt:方案配置,用於配置迭代次數等信息,訓練時直接調用caffe train指定這個文件,就會開始訓練
cifar10_quick_train_test.prototxt:訓練網絡配置,用來設置訓練用的網絡,這個文件的名字會在solver.prototxt里指定
cifar10_quick_iter_4000.caffemodel.h5:訓練出來的模型,后面就用這個模型來做分類
cifar10_quick_iter_4000.solverstate.h5:也是訓練出來的,應該是用來中斷后繼續訓練用的文件
cifar10_quick.prototxt:分類用的網絡