從零開始學習MXnet(一)


  最近工作要開始用到MXnet,然而MXnet的文檔寫的實在是.....所以在這記錄點東西,方便自己,也方便大家。

  我覺得搞清楚一個框架怎么使用,第一步就是用它來訓練自己的數據,這是個很關鍵的一步。 

一、MXnet數據預處理

  整個數據預處理的代碼都集成在了toosl/im2rec.py中了,這個首先要造出一個list文件,lst文件有三列,分別是index label 圖片路徑。如下圖所示:

       

  我這個label是瞎填的,所以都是0。另外最新的MXnet上面的im2rec是有問題的,它生成的list所有的index都是0,不過據說這個index沒什么用.....但我還是改了一下。把yield生成器換成直接append即可。

  執行的命令如下:

    sudo python im2rec.py --list=True /home/erya/dhc/result/try /home/erya/dhc/result/ --recursive=True --shuffle=true --train-ratio=0.8 

  每個參數的意義在代碼內部都可以查到,簡單說一下這里用到的:--list=True說明這次的目的是make list,后面緊跟的是生成的list的名字的前綴,我這里是加了路徑,然后是圖片所在文件夾的路徑,recursive是是否迭代的進入文件夾讀取圖片,--train-ratio則表示train和val在數據集中的比例。

執行上面的命令后,會得到三個文件:

 

    然后再執行下面的命令生成最后的rec文件:

  sudo python im2rec.py /home/erya/dhc/result/try_val.lst  /home/erya/dhc/result --quality=100 

  以及,sudo python im2rec.py /home/erya/dhc/result/try_train.lst  /home/erya/dhc/result --quality=100 

 來生成相應的lst文件的rec文件,參數意義太簡單就不說了..看着就明白,result是我存放圖片的目錄。

 

  這樣最終就完成了數據的預處理,簡單的說,就是先生成lst文件,這個其實完全可以自己做,而且后期我做segmentation的時候,label就是圖片了..

 

二、非常簡單的小demo

先上代碼:

  

 1 import mxnet as mx
 2 import logging
 3 import numpy as np
 4 
 5 logger = logging.getLogger()
 6 logger.setLevel(logging.DEBUG)#暫時不需要管的log
 7 def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
 8     conv = mx.symbol.Convolution(data=data, workspace=256,
 9                                  num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
10     return conv   #我把這個刪除到只有一個卷積的操作
11 def DownsampleFactory(data, ch_3x3):
12     # conv 3x3
13     conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
14     # pool
15     pool = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
16     # concat
17     concat = mx.symbol.Concat(*[conv, pool])
18     return concat
19 def SimpleFactory(data, ch_1x1, ch_3x3):
20     # 1x1
21     conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
22     # 3x3
23     conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
24     #concat
25     concat = mx.symbol.Concat(*[conv1x1, conv3x3])
26     return concat
27 if __name__ == "__main__":
28     batch_size = 1
29     train_dataiter = mx.io.ImageRecordIter(
30         shuffle=True,
31         path_imgrec="/home/erya/dhc/result/try_train.rec",
32         rand_crop=True,
33         rand_mirror=True,
34         data_shape=(3,28,28),
35         batch_size=batch_size,
36         preprocess_threads=1)#這里是使用我們之前的創造的數據,簡單的說就是要自己寫一個iter,然后把相應的參數填進去。
37     test_dataiter = mx.io.ImageRecordIter(
38         path_imgrec="/home/erya/dhc/result/try_val.rec",
39         rand_crop=False,
40         rand_mirror=False,
41         data_shape=(3,28,28),
42         batch_size=batch_size,
43         round_batch=False,
44         preprocess_threads=1)#同理
45     data = mx.symbol.Variable(name="data")
46     conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
47     in3a = SimpleFactory(conv1, 32, 32)
48     fc = mx.symbol.FullyConnected(data=in3a, num_hidden=10)
49     softmax = mx.symbol.SoftmaxOutput(name='softmax',data=fc)#上面就是定義了一個巨巨巨簡單的結構
50     # For demo purpose, this model only train 1 epoch
51     # We will use the first GPU to do training
52     num_epoch = 1
53     model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
54                              learning_rate=0.05, momentum=0.9, wd=0.00001) #將整個model訓練的架構定下來了,類似於caffe里面solver所做的事情。
55 
56 # we can add learning rate scheduler to the model
57 # model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
58 #                              learning_rate=0.05, momentum=0.9, wd=0.00001,
59 #                              lr_scheduler=mx.misc.FactorScheduler(2))
60 model.fit(X=train_dataiter,
61           eval_data=test_dataiter,
62           eval_metric="accuracy",
63           batch_end_callback=mx.callback.Speedometer(batch_size))#開跑數據。

 

 

 

 

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM