這里介紹一種深度殘差網(deep residual networks)的訓練過程:
1、通過下面的地址下載基於python的訓練代碼:
https://github.com/dnlcrl/deep-residual-networks-pyfunt
2、這些訓練代碼需要和pydataset包。下面介紹這兩個包的安裝方法。
(1)pyfunt需要安裝。
用命令:pip install git+git://github.com/dnlcrl/PyFunt.git
進行下載安計。
安裝時numpy需要1.11及以上,但筆者機器上的numpy為1.10,因此,還采用了如下命令對其進行升級:
pip install numpy --upgrade
(2)pydataset
pydataset主要用於數據預處理。 由於pydataset包需要cv2(opencv for python),但無法用pip安裝cv2(該項目好像停止了)。因此只有通過下載opencv3.0(我下載的是opencv-3.0.0.exe),然后解壓,在解壓后的\build\python\2.7\x64目錄下將cv2.pyd文件拷貝到python的site-packages目錄下,注意,筆者機器上安裝的是64位的python2.7,所以選擇的是這個路徑,不同的python版本和平台,其路徑不一樣。
然后需要設置opencv的路徑。具體設置方法見下圖(注意,筆者的opencv是安裝在F:\tool\opencv,不同的opencv安裝路徑,其設置內容不一樣):
然后下載通過: https://github.com/dnlcrl//PyDatSet 下載PyDatSet,解壓,在所解壓的目錄python setup.py install來安裝。
3、下載cafir10 for python的數據集,然后解壓,記住這個目錄,在第一次運行train.py時,需要輸入這個目錄的全路徑。
4、由於源代碼有bug,需要修改源代碼。
主要修改的源代碼有:
(1) train.py中的 NUM_TRAIN = 40000 (源代碼是5000);
(2)cifar10.py文件(在python安裝目錄下的lib\site-packages\pydatset目錄下)中的load函數中的 for b in range(1, 5);(源代碼為6)
(3)cifar10.py文件(在python安裝目錄下的lib\site-packages\pydatset目錄下)中的load_CIFAR_batch函數中的with open(filename, 'rb') as f;(源代碼為'r');
5、通過執行python train.py就可以進行訓練了。