比如你在mnist的prototxt中定義圖輸入是單通道的,也就是channel=1,然后如果直接調用classify.py腳本來測試的話,是會報錯,錯誤跟一下類似。
Source param shape is 128 3 32 32; target param shape is 128 1 32 32.
意思就是網絡要求輸入是1 channel,而你讀入的數據是3 channels。
即使你再調用這個腳本之前,已經把圖轉換成灰度圖了,也是不行。
那是因為caffe.io.load_image讀入數據的時候,總是會把數據轉成3 channels。
所以,我們需要換一種方式讀入數據。
具體做法
- 找到classify.py中
inputs = [caffe.io.load_image(im_f) for im_f in glob.glob(args.input_file + '/*.' + args.ext)]
- 替換成
tmp = []
for _ in inputs:
img = skimage.img_as_float(skimage.io.imread(_)).astype(np.float32)
if len(img.shape) == 2:
# 設置channel為1
img = img.reshape(img.shape[0], img.shape[1], 1)
tmp.append(img)
inputs = tmp
這里是修改的測試集是從一個目錄讀入的,如果測試集是單獨的一張圖,修改方式也類似。