caffe簡易上手指南(二)—— 訓練我們自己的數據


訓練我們自己的數據

 

本篇繼續之前的教程,下面我們嘗試使用別人定義好的網絡,來訓練我們自己的網絡。

1、准備數據

首先很重要的一點,我們需要准備若干種不同類型的圖片進行分類。這里我選擇從ImageNet上下載了3個分類的圖片(Cat,Dog,Fish)。

圖片需要分兩批:訓練集(train)、測試集(test),一般訓練集與測試集的比例大概是5:1以上,此外每個分類的圖片也不能太少,我這里每個分類大概選了5000張訓練圖+1000張測試圖。

找好圖片以后,需要准備以下文件:

words.txt:分類序號與分類對應關系(注意:要從0開始標注

0 cat
1 dog
2 fish

 

train.txt:標明訓練圖片路徑及其對應分類,路徑和分類序號直接用空格分隔,最好隨機打亂一下圖片

/opt/caffe/examples/my_simple_image/data/cat_train/n02123045_4416.JPEG 0
/opt/caffe/examples/my_simple_image/data/cat_train/n02123045_3568.JPEG 0
/opt/caffe/examples/my_simple_image/data/fish_train/n02512053_4451.JPEG 2
/opt/caffe/examples/my_simple_image/data/cat_train/n02123045_3179.JPEG 0
/opt/caffe/examples/my_simple_image/data/cat_train/n02123045_6956.JPEG 0
/opt/caffe/examples/my_simple_image/data/cat_train/n02123045_10143.JPEG 0
......

 

val.txt:標明測試圖片路徑及其對應分類

/opt/caffe/examples/my_simple_image/data/dog_val/n02084071_12307.JPEG 1
/opt/caffe/examples/my_simple_image/data/dog_val/n02084071_10619.JPEG 1
/opt/caffe/examples/my_simple_image/data/cat_val/n02123045_13360.JPEG 0
/opt/caffe/examples/my_simple_image/data/cat_val/n02123045_13060.JPEG 0
/opt/caffe/examples/my_simple_image/data/cat_val/n02123045_11859.JPEG 0
......

 

2、生成lmdb文件

lmdb是caffe使用的一種輸入數據格式,相當於我們把圖片及其分類重新整合一下,變成一個數據庫輸給caffe訓練。

這里我們使用caffenet的create_imagenet.sh文件修改,主要是重新指定一下路徑:

EXAMPLE=examples/my_simple_image/ DATA=examples/my_simple_image/data/ TOOLS=build/tools

TRAIN_DATA_ROOT=/ VAL_DATA_ROOT=/ # 這里我們打開resize,需要把所有圖片尺寸統一 RESIZE=true if $RESIZE; then RESIZE_HEIGHT=256 RESIZE_WIDTH=256 else RESIZE_HEIGHT=0 RESIZE_WIDTH=0 fi ....... echo "Creating train lmdb..." GLOG_logtostderr=1 $TOOLS/convert_imageset \ --resize_height=$RESIZE_HEIGHT \ --resize_width=$RESIZE_WIDTH \ --shuffle \ $TRAIN_DATA_ROOT \ $DATA/train.txt \ $EXAMPLE/ilsvrc12_train_lmdb  #生成的lmdb路徑 echo "Creating val lmdb..." GLOG_logtostderr=1 $TOOLS/convert_imageset \ --resize_height=$RESIZE_HEIGHT \ --resize_width=$RESIZE_WIDTH \ --shuffle \ $VAL_DATA_ROOT \ $DATA/val.txt \ $EXAMPLE/ilsvrc12_val_lmdb #生成的lmdb路徑
echo "Done."

 

3、生成mean_file

下面我們用lmdb生成mean_file,用於訓練(具體做啥用的我還沒研究。。。)

這里也是用imagenet例子的腳本:

EXAMPLE=examples/my_simple_image
DATA=examples/my_simple_image TOOLS=build/tools $TOOLS/compute_image_mean $EXAMPLE/ilsvrc12_train_lmdb $DATA/imagenet_mean.binaryproto echo "Done."

 

 

4、修改solver、train_val配置文件

這里我們可以選用cifar的網絡,也可以用imagenet的網絡,不過后者的網絡結構更復雜一些,為了學習,我們就用cifar的網絡來改。

把cifar的兩個配置文件拷過來:

cifar10_quick_solver.prototxt
cifar10_quick_train_test.prototxt

首先修改cifar10_quick_train_test.prototxt的路徑以及輸出層數量(標注出黑體的部分):

name: "CIFAR10_quick"
layer {
  name: "cifar" type: "Data" top: "data" top: "label" include { phase: TRAIN } transform_param { mean_file: "examples/my_simple_image/imagenet_mean.binaryproto" } data_param {
source: "examples/my_simple_image/ilsvrc12_train_lmdb" batch_size: 50 #一次訓練的圖片數量,一般指定50也夠了 backend: LMDB } } layer { name: "cifar" type: "Data" top: "data" top: "label" include { phase: TEST } transform_param { mean_file: "examples/my_simple_image/imagenet_mean.binaryproto" } data_param { source: "examples/my_simple_image/ilsvrc12_val_lmdb" batch_size: 50 #一次訓練的圖片數量 backend: LMDB } }
..........
layer { name: "ip2" type: "InnerProduct" bottom: "ip1" top: "ip2" .......... inner_product_param { num_output: 3 #輸出層數量,就是你要分類的個數 weight_filler { type: "gaussian" std: 0.1 } bias_filler { type: "constant" } } } ......

 

cifar10_quick_solver.prototxt的修改根據自己的實際需要:

net: "examples/my_simple_image/cifar/cifar10_quick_train_test.prototxt" #網絡文件路徑
test_iter: 20 #測試執行的迭代次數
test_interval: 10 #迭代多少次進行測試 base_lr: 0.001 #迭代速率,這里我們改小了一個數量級,因為數據比較少
momentum: 0.9 weight_decay: 0.004 lr_policy: "fixed" #采用固定學習速率的模式display: 1 #迭代幾次就顯示一下信息,這里我為了及時跟蹤效果,改成1 max_iter: 4000 #最大迭代次數 snapshot: 1000 #迭代多少次生成一次快照 snapshot_prefix: "examples/my_simple_image/cifar/cifar10_quick" #快照路徑和前綴 solver_mode: CPU #CPU或者GPU

 

5、開始訓練

運行下面的命令,開始訓練(為了方便可以做成腳本)

./build/tools/caffe train --solver=examples/my_simple_image/cifar/cifar10_quick_solver.prototxt

 

6、小技巧

網絡的配置和訓練其實有一些小技巧。

- 訓練過程中,正確率時高時低是很正常的現象,但是總體上是要下降的

- 觀察loss值的趨勢,如果迭代幾次以后一直在增大,最后變成nan,那就是發散了,需要考慮減小訓練速率,或者是調整其他參數

- 數據不能太少,如果太少的話很容易發散

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM