caffe簡易上手指南(一)—— 運行cifar例子


簡介

caffe是一個友好、易於上手的開源深度學習平台,主要用於圖像的相關處理,可以支持CNN等多種深度學習網絡。

基於caffe,開發者可以方便快速地開發簡單的學習網絡,用於分類、定位等任務,也可以用於科研,在其源碼基礎上進行修改,實現自己的算法。

本文的主要目的,是介紹caffe的基本使用方法,希望通過本文,能讓普通的工程師可以使用caffe訓練自己的簡單模型。

本文主要包括以下內容:運行caffe的例子訓練cifar訓練集、使用別人定義好的網絡訓練自己的數據、使用訓練好的模型fine tune自己的數據。

 

 

背景知識簡介

深度學習是機器學習的一個分支,主要目標在於通過學習的方法,解決以往普通編程無法解決的問題,例如:圖像識別、文字識別等等。

機器學習里的“學習”,指通過向程序輸入經驗數據,通過若干次“迭代”,不斷改進算法參數,最終能夠獲得“模型”,使用新數據輸入模型,計算得出想要的結果。

例如圖像分類任務中,經驗數據是圖片和對應的文字,訓練出模型后,將新圖片使用模型運算,就可以知道其對應的類別。

以上只是簡單介紹,這里還是建議先學習機器學習、卷積神經網絡的相關基礎知識。

 

 

安裝

這一部分網上有不少教程,這里就略掉,另外,我是用docker的鏡像直接安裝的,網上可以直接搜到帶caffe的docker鏡像。好處是省去安裝環境的時間,缺點是后面設置文件會麻煩一些,建議從長計議還是直接安裝在電腦上。

 

 

訓練cifar訓練集

cifar是一個常見的圖像分類訓練集,包括上萬張圖片及20個分類,caffe提供了一個網絡用於分類cifar數據集。

cifar網絡的定義在examples/cifar10目錄下,訓練的過程十分簡單。

(以下命令均在caffe默認根目錄下運行,下同)

 

1、獲取訓練數據

cd $CAFFE_ROOT
./data/cifar10/get_cifar10.sh
./examples/cifar10/create_cifar10.sh

 

2、開始訓練

cd $CAFFE_ROOT
./examples/cifar10/train_quick.sh

 

3、訓練完成后我們會得到:

  cifar10_quick_iter_4000.caffemodel.h5

  cifar10_quick_iter_4000.solverstate.h5

  此時,我們就訓練得到了模型,用於后面的分類。

 

4、下面我們使用模型來分類新數據

先直接用一下別人的模型分類試一下:(默認用的ImageNet的模型)

python python/classify.py examples/images/cat.jpg foo

 

下面我們來指定自己的模型進行分類:

python python/classify.py --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --center_only  examples/images/cat.jpg foo

上面這句話的意思是,使用cifar10_quick.prototxt網絡 + cifar10_quick_iter_4000.caffemodel.h5模型,對examples/images/cat.jpg圖片進行分類。

 

默認的classify腳本不會直接輸出結果,而是會把結果輸入到foo文件里,不太直觀,這里我在網上找了一個修改版,添加了一些參數,可以輸出概率最高的分類。

替換python/classify.py,下載地址:http://download.csdn.net/detail/caisenchuan/9513196

 

這個腳本添加了兩個參數,可以指定labels_file,然后可以直接把分類結果輸出出來:

python python/classify.py --print_results --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --labels_file data/cifar10/cifar10_words.txt  --center_only  examples/images/cat.jpg foo

輸出結果:

Loading file: examples/images/cat.jpg
Classifying 1 inputs.
predict 3 inputs.
Done in 0.02 s.
Predictions : [[ 0.03903743  0.00722749  0.04582177  0.44352672  0.01203315  0.11832549
   0.02335102  0.25013766  0.03541689  0.02512246]]
python/classify.py:176: FutureWarning: sort(columns=....) is deprecated, use sort_values(by=.....)
  labels = labels_df.sort('synset_id')['name'].values
[('cat', '0.44353'), ('horse', '0.25014'), ('dog', '0.11833'), ('bird', '0.04582'), ('airplane', '0.03904')]
上面標明了各個分類的順序和置信度
Saving results into foo

 

Tips

最后,總結一下訓練一個網絡用到的相關文件:

cifar10_quick_solver.prototxt:方案配置,用於配置迭代次數等信息,訓練時直接調用caffe train指定這個文件,就會開始訓練

cifar10_quick_train_test.prototxt:訓練網絡配置,用來設置訓練用的網絡,這個文件的名字會在solver.prototxt里指定

cifar10_quick_iter_4000.caffemodel.h5:訓練出來的模型,后面就用這個模型來做分類

cifar10_quick_iter_4000.solverstate.h5:也是訓練出來的,應該是用來中斷后繼續訓練用的文件

cifar10_quick.prototxt:分類用的網絡

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM