神經網絡輸入層神經單元個數:784 (圖像大小28*28)
輸出層 :10 (10個類別分類,即10個數字)
隱藏層個數 :2
第1個隱藏層的神經單元數 :50
第2個隱藏層的神經單元數 :100
先定義get_data()、init_network()、predict()這3個函數:
1 def get_data(): 2 (x_train,t_train),(x_test,t_test)=load_mnist(nomalize=True,flatten=True,one_hot_label=False) 3 return x_test,t_test 4 5 def init_natwork(): 6 with open("sample_weight.pkl",'rb') as f: 7 network=pickle.load(f) 8 return network 9 10 def predict(network,x): 11 W1,W2,W3=network['W1'],network['W2'],network['W3'] 12 b1,b2,b3=network['b1'],network['b2'],network['b3'] 13 a1 = np.dot(x, W1) + b1 14 z1 = sigmoid(a1) 15 a2 = np.dot(z1, W2) + b2 16 z2 = sigmoid(a2) 17 a3 = np.dot(z2, W3) + b3 18 y = softmax(a3) 19 return y
init_network()
會讀入保存在 pickle 文件 sample_weight.pkl
中的學習到的權重參數 {8[因為之前我們假設學習已經完成,所以學習到的參數被保存下來。假設保存在 sample_weight.pkl
文件中,在推理階段,我們直接加載這些已經學習到的參數。——譯者注]}。這個文件中以字典變量的形式保存了權重和偏置參數。剩余的 2 個函數,和前面介紹的代碼實現基本相同,無需再解釋。現在,我們用這 3 個函數來實現神經網絡的推理處理。然后,評價它的識別精度(accuracy),即能在多大程度上正確分類。
1 x, t = get_data() 2 network = init_network() 3 4 accuracy_cnt = 0 5 for i in range(len(x)): 6 y = predict(network, x[i]) 7 p = np.argmax(y) # 獲取概率最高的元素的索引 8 if p == t[i]: 9 accuracy_cnt += 1 10 11 print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
首先獲得 MNIST 數據集,生成網絡。接着,用 for
語句逐一取出保存在 x
中的圖像數據,用 predict()
函數進行分類。predict()
函數以 NumPy 數組的形式輸出各個標簽對應的概率。比如輸出 [0.1, 0.3, 0.2, ..., 0.04]
的數組,該數組表示“0”的概率為 0.1,“1”的概率為 0.3,等等。然后,我們取出這個概率列表中的最大值的索引(第幾個元素的概率最高),作為預測結果。可以用 np.argmax(x)
函數取出數組中的最大值的索引,np.argmax(x)
將獲取被賦給參數 x
的數組中的最大值元素的索引。最后,比較神經網絡所預測的答案和正確解標簽,將回答正確的概率作為識別精度。
下面我們進行基於批處理的代碼實現。這里用粗體顯示與之前的實現的不同之處。
1 x, t = get_data() 2 network = init_network() 3 4 batch_size = 100 # 批數量 5 accuracy_cnt = 0 6 7 for i in range(0, len(x), batch_size): 8 x_batch = x[i:i+batch_size] 9 y_batch = predict(network, x_batch) 10 p = np.argmax(y_batch, axis=1) 11 accuracy_cnt += np.sum(p == t[i:i+batch_size])
我們來逐個解釋粗體的代碼部分。首先是 range()
函數。range()
函數若指定為 range(start, end)
,則會生成一個由 start
到 end-1
之間的整數構成的列表。若像 range(start, end, step)
這樣指定 3 個整數,則生成的列表中的下一個元素會增加 step
指定的值。我們來看一個例子。
>>> list( range(0, 10) ) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> list( range(0, 10, 3) ) [0, 3, 6, 9]
在 range()
函數生成的列表的基礎上,通過 x[i:i+batch_size]
從輸入數據中抽出批數據。x[i:i+batch_n]
會取出從第 i
個到第 i+batch_n
個之間的數據。本例中是像 x[0:100]
、x[100:200]
……這樣,從頭開始以 100 為單位將數據提取為批數據。
然后,通過 argmax()
獲取值最大的元素的索引。不過這里需要注意的是,我們給定了參數 axis=1
。這指定了在 100 × 10 的數組中,沿着第 1 維方向(以第 1 維為軸)找到值最大的元素的索引(第 0 維對應第 1 個維度){9[矩陣的第 0 維是列方向,第 1 維是行方向。——譯者注]}。這里也來看一個例子。
>>> x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6], ... [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]]) >>> y = np.argmax(x, axis=1) >>> print(y) [1 2 1 0]
最后,我們比較一下以批為單位進行分類的結果和實際的答案。為此,需要在 NumPy 數組之間使用比較運算符(==
)生成由True/False
構成的布爾型數組,並計算True
的個數。我們通過下面的例子進行確認。
>>> y = np.array([1, 2, 1, 0]) >>> t = np.array([1, 2, 0, 0]) >>> print(y==t) [True True False True] >>> np.sum(y==t) 3
仔細看過去,有很多內容之前不懂,看懂以后心里豁然開朗,不得不說真的是講得很棒,清晰又容易理解。