折騰了幾天,爬了大大小小若干的坑,特記錄如下。代碼在最后面。
環境:
Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU
方法:
先用MNIST手寫數字庫對CNN(卷積神經網絡)進行訓練,准確度達到98%以上時,再准備獨家手寫數字10個、畫圖軟件編輯的數字10個共計20個,讓訓練好的CNN進行識別,考察其識別准確度。
調試代碼:
坑1:ModuleNotFoundError: No module named 'google'
解決:pip install protobuf
不用翻牆
坑2:ModuleNotFoundError: No module named 'absl'
解決:pip install absl-py
坑3:tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_2' with dtype float
解決:這個問題折騰我好久,但是最終的解決方法很無語。。。
原來的代碼是這樣的:
output = sess.run(y_conv, feed_dict={x: ndarrayImgs}) # ndarrayImgs為自己的樣本圖片數據
查了不少資料,最后發現是自己少寫了一個參數 /笑哭/笑哭, 寫成這樣就沒問題了:
output = sess.run(y_conv, feed_dict={x: ndarrayImgs, keep_prob:1.0})
代碼調通了之后,大坑來了:訓練后的CNN識別自己的手寫數字和用畫圖軟件編輯出來的數字,正確率只有70%左右,慘不忍睹。
考慮到上面20個數字都是五官端正的,那么准確率低多半是其它原因。調試思路:
1)檢查20個數字圖片的格式:灰度圖片,黑底白字,28x28像素。沒問題。
2)用MNIST自帶的測試數據進行測試,正確率95%左右。說明CNN訓練的還算到位。
3)去網上搜索,終於在知乎里發現了一條回復:MNIST的數字都是20*20大小,圖片大小28*28。把自己的圖片伸縮到20*20大小,然后平移到28*28的中心就可以了。
納尼??原來數字輪廓大小是20x20像素,這個細節我沒注意到。開動PS,利用裁切和調整畫布功能,對圖片處理了一番。
附:MNIST數據庫及其說明 http://yann.lecun.com/exdb/mnist/
再次測試,正確率在85-90%左右,有明顯提升。
然而仔細分析發現,有幾個數字的識別結果經常出錯,分別是手寫的6、7、9。將這幾個數字的圖片和樣本庫中的圖片對比了一下,猜想可能是這幾個圖片中的數字的線條有些細,於是用PS又調整了一下,把線條變粗,結果識別正確率可以達到95-100%了(奇怪的是,數字1-5線條也細,為何能准確識別?)
調試過程記錄完畢,放代碼。使用時注意系統環境和相關軟件版本,如開頭所述。
這個代碼在每次識別前都會先訓練,在CPU上進行計算真是痛苦。。。以后打算將訓練和預測分開,訓練好的模型保存起來,預測的時候直接加載,這樣能省不少時間。
代碼沒優化,有點凌亂,建議移步去看我的《使用TensorFlow的卷積神經網絡識別手寫數字》1、2、3系列。
1 import matplotlib 2 import matplotlib.pyplot as plt 3 import matplotlib.cm as cm 4 import pylab 5 from tensorflow.examples.tutorials.mnist import input_data 6 7 8 def showMnistImg(nBytes): 9 imgBytes = nBytes.reshape((28, 28)) 10 print(imgBytes) 11 plt.figure(figsize=(2.8,2.8)) 12 #plt.grid() #開啟網格 13 plt.imshow(imgBytes, cmap=cm.gray) 14 pylab.show() 15 16 17 def MaxMinNormalization(x,Max,Min): 18 x = (x - Min) / (Max - Min); 19 return x; 20 21 22 def loadHandWritingImage(strFilePath): 23 im=Image.open(strFilePath, 'r') 24 ndarrayImg = np.array(im.convert("L"), dtype='float64') 25 26 return ndarrayImg 27 28 def normalizeImage(ndarrayImg, maxVal = 255, minVal = 0): 29 w, h = ndarrayImg.shape[0], ndarrayImg.shape[1] 30 for i in range(w): 31 for j in range(h): 32 ndarrayImg[i,j] = MaxMinNormalization(ndarrayImg[i,j], maxVal, minVal) #??? 33 34 return ndarrayImg 35 36 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 37 38 # 單個手寫數字的784個字節的灰度值,浮點數,范圍[0,1) 39 print('type(mnist.train.images): ', type(mnist.train.images)) # <class 'numpy.ndarray'> 40 print('mnist.train.images.shape: ', mnist.train.images.shape) 41 ##print(mnist.train.images[0]) 42 ##showMnistImg(mnist.train.images[0]) 43 44 45 # 單個手寫數字的標簽 46 # 一個one-hot向量除了某一位的數字是1以外其余各維度數字都是0 47 # 數字n將表示成一個只有在第n維度(從0開始)數字為1的10維向量。 48 #print('type(mnist.train.labels[0]): ', type(mnist.train.labels[0]))# <class 'numpy.ndarray'> 49 #print(mnist.train.labels[19]) # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 50 51 52 53 #構造自己的手寫圖片集合,作為test。 cnblogs.com/hatemath 54 from PIL import * 55 import numpy as np 56 import tensorflow as tf 57 58 # 構建測試樣本集合 59 files = ['0.png', '1.png', '2.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png', 60 '00.png', '11.png', '22.png', '33.png', '44.png', '55.png', '66.png', '77.png', '88.png', '99.png'] 61 62 ndarrayImgs = np.zeros((len(files), 784)) # x行784列 63 #print('type(ndarrayImgs): ', type(ndarrayImgs)) 64 #print('ndarrayImgs.shape: ', ndarrayImgs.shape) 65 66 index = 0 67 for file in files: 68 69 # 加載圖片 70 ndarrayImg = loadHandWritingImage('numbers/' + file) 71 72 #print('type(ndarrayImg): ', type(ndarrayImg)) 73 #print(ndarrayImg) 74 75 # 歸一化 76 normalizeImage(ndarrayImg) 77 78 # 轉為1x784的數組 79 ndarrayImg = ndarrayImg.reshape((1, 784)) 80 #print('type(ndarrayImg): ', type(ndarrayImg)) 81 #print('ndarrayImg.shape: ', ndarrayImg.shape) 82 83 # 放到測試樣本集中 84 ndarrayImgs[index] = ndarrayImg 85 index = index + 1 86 87 88 # 構建測試樣本的實際值集合,用於計算正確率 89 ndarrayLabels = np.array([ [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 90 [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], 91 [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], 92 [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], 93 [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], 94 [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], 95 [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], 96 [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], 97 [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], 98 [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], 99 [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 100 [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], 101 [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], 102 [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], 103 [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], 104 [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], 105 [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], 106 [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], 107 [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.], 108 [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.] 109 ]) 110 print('type(ndarrayLabels): ', type(ndarrayLabels)) 111 112 113 #print(ndarrayImgs[3]) 114 ##showMnistImg(ndarrayImgs[3]) 115 #print(ndarrayLabels[3]) 116 117 118 # 下面開始CNN相關 119 120 def conv2d(x, W): 121 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 122 123 def max_pool_2x2(x): 124 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 125 strides=[1, 2, 2, 1], padding='SAME') 126 127 128 def weight_variable(shape): 129 initial = tf.truncated_normal(shape, stddev=0.1) 130 return tf.Variable(initial) 131 132 def bias_variable(shape): 133 initial = tf.constant(0.1, shape=shape) 134 return tf.Variable(initial) 135 136 137 x = tf.placeholder(tf.float32, shape=[None, 784]) 138 y_ = tf.placeholder(tf.float32, shape=[None, 10]) 139 140 141 W_conv1 = weight_variable([5, 5, 1, 32]) 142 b_conv1 = bias_variable([32]) 143 144 x_image = tf.reshape(x, [-1, 28, 28, 1]) 145 146 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 147 h_pool1 = max_pool_2x2(h_conv1) 148 149 150 W_conv2 = weight_variable([5, 5, 32, 64]) 151 b_conv2 = bias_variable([64]) 152 153 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 154 h_pool2 = max_pool_2x2(h_conv2) 155 156 157 158 W_fc1 = weight_variable([7 * 7 * 64, 1024]) 159 b_fc1 = bias_variable([1024]) 160 161 h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 162 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 163 164 165 keep_prob = tf.placeholder(tf.float32) 166 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 167 168 169 W_fc2 = weight_variable([1024, 10]) 170 b_fc2 = bias_variable([10]) 171 172 y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 173 #print(y_conv) 174 175 176 177 cross_entropy = tf.reduce_mean( 178 tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y_conv)) 179 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 180 correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 181 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 182 183 with tf.Session() as sess: 184 sess.run(tf.global_variables_initializer()) 185 for i in range(1000): 186 batch = mnist.train.next_batch(50) 187 188 if i % 100 == 0: 189 train_accuracy = accuracy.eval(feed_dict={ 190 x: batch[0], y_: batch[1], keep_prob: 1.0}) 191 print('step %d, training accuracy %g' % (i, train_accuracy)) 192 if(train_accuracy>0.98): 193 break 194 195 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 196 197 198 199 print('測試Mnist test數據集 准確率 %g' % accuracy.eval(feed_dict={ 200 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) 201 202 # 測試耗時 203 import time 204 start = time.time() 205 accu = accuracy.eval(feed_dict={x: ndarrayImgs, y_: ndarrayLabels, keep_prob: 1.0}) 206 end = time.time() 207 208 print('識別zzh手寫數據%d個, 准確率為 %g, 每個耗時%g秒' % (len(ndarrayImgs), accu, (end-start)/len(ndarrayImgs))) 209 210 output = sess.run(y_conv, feed_dict={x: ndarrayImgs, keep_prob:1.0}) 211 print('預測值:', output.argmax(axis=1)) # axis:0表示按列,1表示按行 212 print('實際值:', ndarrayLabels.argmax(axis=1))
貼2次運行結果,供參考:
Extracting MNIST_data\train-images-idx3-ubyte.gz Extracting MNIST_data\train-labels-idx1-ubyte.gz Extracting MNIST_data\t10k-images-idx3-ubyte.gz Extracting MNIST_data\t10k-labels-idx1-ubyte.gz type(mnist.train.images): <class 'numpy.ndarray'> mnist.train.images.shape: (55000, 784) type(ndarrayLabels): <class 'numpy.ndarray'> step 0, training accuracy 0.14 step 100, training accuracy 0.86 step 200, training accuracy 0.82 step 300, training accuracy 0.98 測試Mnist test數據集 准確率 0.9213 識別zzh手寫數據20個, 准確率為 0.9, 每個耗時0.000750029秒 預測值: [0 1 2 3 4 5 6 1 8 9 0 1 2 3 4 5 6 2 8 9] 實際值: [0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9] >>>
Extracting MNIST_data\train-images-idx3-ubyte.gz Extracting MNIST_data\train-labels-idx1-ubyte.gz Extracting MNIST_data\t10k-images-idx3-ubyte.gz Extracting MNIST_data\t10k-labels-idx1-ubyte.gz type(mnist.train.images): <class 'numpy.ndarray'> mnist.train.images.shape: (55000, 784) type(ndarrayLabels): <class 'numpy.ndarray'> step 0, training accuracy 0.14 step 100, training accuracy 0.84 step 200, training accuracy 0.92 step 300, training accuracy 0.88 step 400, training accuracy 0.96 step 500, training accuracy 0.98 測試Mnist test數據集 准確率 0.9445 識別zzh手寫數據20個, 准確率為 1, 每個耗時0.000779998秒 預測值: [0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9] 實際值: [0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9] >>>
總結:
1) CNN雖然是個神器,但是要想提高手寫數字識別率,除了CNN的訓練外,還要在手寫圖片上做足前戲,啊呸,做足預處理,要把手寫圖片按照MNIST規范進行調整,畢竟訓練的樣本就是按照那些規范來的。
2) 再次重申一下圖片規范:灰度圖片,黑底白字,數字的外圍輪廓大小是20x20像素,圖片總體的大小是28x28像素。自動化的預處理可以用opencv來做。
3) 用CPU做訓練,非常慢。我的機器上,訓練500次耗時1分鍾,每次調試都這么等,太浪費時間了。考慮保存/加載模型的方案,或者搞一塊N卡,用CUDA計算應該會快很多。