python: 神經網絡實現MNIST圖像識別


神經網絡輸入層神經單元個數:784      (圖像大小28*28)

               輸出層                      :10        (10個類別分類,即10個數字)

              隱藏層個數                 :2

     第1個隱藏層的神經單元數   :50

     第2個隱藏層的神經單元數   :100

 

先定義get_data()、init_network()predict()這3個函數:

 

 1 def get_data():
 2     (x_train,t_train),(x_test,t_test)=load_mnist(nomalize=True,flatten=True,one_hot_label=False)
 3     return x_test,t_test
 4 
 5 def init_natwork():
 6     with open("sample_weight.pkl",'rb') as f:
 7         network=pickle.load(f)
 8     return network
 9 
10 def predict(network,x):
11     W1,W2,W3=network['W1'],network['W2'],network['W3']
12     b1,b2,b3=network['b1'],network['b2'],network['b3']
13     a1 = np.dot(x, W1) + b1
14     z1 = sigmoid(a1)
15     a2 = np.dot(z1, W2) + b2
16     z2 = sigmoid(a2)
17     a3 = np.dot(z2, W3) + b3
18     y = softmax(a3)
19     return y

init_network() 會讀入保存在 pickle 文件 sample_weight.pkl 中的學習到的權重參數 {8[因為之前我們假設學習已經完成,所以學習到的參數被保存下來。假設保存在 sample_weight.pkl 文件中,在推理階段,我們直接加載這些已經學習到的參數。——譯者注]}。這個文件中以字典變量的形式保存了權重和偏置參數。剩余的 2 個函數,和前面介紹的代碼實現基本相同,無需再解釋。現在,我們用這 3 個函數來實現神經網絡的推理處理。然后,評價它的識別精度(accuracy),即能在多大程度上正確分類。

 

 

 

 1 x, t = get_data()
 2 network = init_network()
 3 
 4 accuracy_cnt = 0
 5 for i in range(len(x)):
 6     y = predict(network, x[i])
 7     p = np.argmax(y) # 獲取概率最高的元素的索引
 8     if p == t[i]:
 9         accuracy_cnt += 1
10 
11 print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

 

首先獲得 MNIST 數據集,生成網絡。接着,用 for 語句逐一取出保存在 x 中的圖像數據,用 predict() 函數進行分類。predict() 函數以 NumPy 數組的形式輸出各個標簽對應的概率。比如輸出 [0.1, 0.3, 0.2, ..., 0.04]的數組,該數組表示“0”的概率為 0.1,“1”的概率為 0.3,等等。然后,我們取出這個概率列表中的最大值的索引(第幾個元素的概率最高),作為預測結果。可以用 np.argmax(x) 函數取出數組中的最大值的索引,np.argmax(x) 將獲取被賦給參數 x 的數組中的最大值元素的索引。最后,比較神經網絡所預測的答案和正確解標簽,將回答正確的概率作為識別精度。

 

 

 

下面我們進行基於批處理的代碼實現。這里用粗體顯示與之前的實現的不同之處。

 1 x, t = get_data()
 2 network = init_network()
 3 
 4 batch_size = 100 # 批數量
 5 accuracy_cnt = 0
 6 
 7 for i in range(0, len(x), batch_size):  8     x_batch = x[i:i+batch_size]  9     y_batch = predict(network, x_batch) 10     p = np.argmax(y_batch, axis=1) 11     accuracy_cnt += np.sum(p == t[i:i+batch_size])

我們來逐個解釋粗體的代碼部分。首先是 range() 函數。range() 函數若指定為 range(start, end),則會生成一個由 start 到 end-1 之間的整數構成的列表。若像 range(start, end, step) 這樣指定 3 個整數,則生成的列表中的下一個元素會增加 step 指定的值。我們來看一個例子。

 

>>> list( range(0, 10) )
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> list( range(0, 10, 3) )
[0, 3, 6, 9]

 range() 函數生成的列表的基礎上通過 x[i:i+batch_size] 從輸入數據中抽出批數據x[i:i+batch_n] 會取出從第 i 個到第 i+batch_n 個之間的數據。本例中是像 x[0:100]x[100:200]……這樣,從頭開始以 100 為單位將數據提取為批數據。

然后,通過 argmax() 獲取值最大的元素的索引。不過這里需要注意的是,我們給定了參數 axis=1。這指定了在 100 × 10 的數組中,沿着第 1 維方向(以第 1 維為軸)找到值最大的元素的索引(第 0 維對應第 1 個維度){9[矩陣的第 0 維是列方向,第 1 維是行方向。——譯者注]}。這里也來看一個例子。

 

>>> x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6],
...     [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
>>> y = np.argmax(x, axis=1)
>>> print(y)
[1 2 1 0]

最后,我們比較一下以批為單位進行分類的結果和實際的答案。為此,需要在 NumPy 數組之間使用比較運算符(==)生成由 True/False 構成的布爾型數組,並計算 True 的個數。我們通過下面的例子進行確認。

>>> y = np.array([1, 2, 1, 0])
>>> t = np.array([1, 2, 0, 0])
>>> print(y==t)
[True True False True]
>>> np.sum(y==t)
3

仔細看過去,有很多內容之前不懂,看懂以后心里豁然開朗,不得不說真的是講得很棒,清晰又容易理解。
 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM