神經網絡的訓練和測試 python


  承接上一節,神經網絡需要訓練,那么訓練集來自哪?測試的數據又來自哪?

  《python神經網絡編程》一書給出了訓練集,識別圖片中的數字。測試集的鏈接如下:

  https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_test_10.csv

為了方便,這只是一個小的測試集,才10個。

  訓練集鏈接:https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_train_100.csv

這是包含100個數據的訓練集。

  訓練集和測試集的每段的第一個字母是期望的數字,每段剩余的文本是表示這個數字的像素集合,為784個數據。為了計算,我們要把文本轉化為數字進行存放。把第一個數據當作期望數據,剩余的784個數據當作輸入。因此輸入節點設為784個。輸出節點設為10個,因為要識別的是10個數據0到9。隱藏層節點選為100個,並沒有進行科學的計算。

  

 1 import numpy
 2 import scipy.special
 3 import matplotlib.pyplot as plt
 4 import pylab
 5 # 神經網絡類定義
 6 class NeuralNetwork():
 7     # 初始化神經網絡
 8     def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
 9         # 設置輸入層節點,隱藏層節點和輸出層節點的數量
10         self.inodes = inputnodes
11         self.hnodes = hiddennodes
12         self.onodes = outputnodes
13         # 學習率設置
14         self.lr = learningrate
15         # 權重矩陣設置 正態分布
16         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
17         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
18         # 激活函數設置,sigmod()函數
19         self.activation_function = lambda x: scipy.special.expit(x)
20         pass
21 
22     # 訓練神經網絡
23     def train(self,input_list,target_list):
24         # 轉換輸入輸出列表到二維數組
25         inputs = numpy.array(input_list, ndmin=2).T
26         targets = numpy.array(target_list,ndmin= 2).T
27         # 計算到隱藏層的信號
28         hidden_inputs = numpy.dot(self.wih, inputs)
29         # 計算隱藏層輸出的信號
30         hidden_outputs = self.activation_function(hidden_inputs)
31         # 計算到輸出層的信號
32         final_inputs = numpy.dot(self.who, hidden_outputs)
33         final_outputs = self.activation_function(final_inputs)
34 
35         output_errors = targets - final_outputs
36         hidden_errors = numpy.dot(self.who.T,output_errors)
37 
38         #隱藏層和輸出層權重更新
39         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
40                                         numpy.transpose(hidden_outputs))
41         #輸入層和隱藏層權重更新
42         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
43                                         numpy.transpose(inputs))
44         pass
45     # 查詢神經網絡
46     def query(self, input_list):
47         # 轉換輸入列表到二維數組
48         inputs = numpy.array(input_list, ndmin=2).T
49         # 計算到隱藏層的信號
50         hidden_inputs = numpy.dot(self.wih, inputs)
51         # 計算隱藏層輸出的信號
52         hidden_outputs = self.activation_function(hidden_inputs)
53         # 計算到輸出層的信號
54         final_inputs = numpy.dot(self.who, hidden_outputs)
55         final_outputs = self.activation_function(final_inputs)
56 
57         return final_outputs
58 
59 # 設置每層節點個數
60 input_nodes = 784
61 hidden_nodes = 100
62 output_nodes = 10
63 # 設置學習率為0.3
64 learning_rate = 0.3
65 # 創建神經網絡
66 n = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
67 
68 #讀取訓練數據集 轉化為列表
69 training_data_file = open("D:/mnist_train_100.csv",'r')
70 training_data_list = training_data_file.readlines();
71 training_data_file.close()
72 
73 #訓練神經網絡
74 for record in training_data_list:
75     #根據逗號,將文本數據進行拆分
76     all_values = record.split(',')
77     #將文本字符串轉化為實數,並創建這些數字的數組。
78     inputs = (numpy.asfarray(all_values[1:])/255.0 * 0.99) + 0.01
79     #創建用零填充的數組,數組的長度為output_nodes,加0.01解決了0輸入造成的問題
80     targets = numpy.zeros(output_nodes) + 0.01
81     #使用目標標簽,將正確元素設置為0.99
82     targets[int(all_values[0])] = 0.99
83     n.train(inputs,targets)
84     pass
85 
86 #讀取測試文件
87 test_data_file = open("D:/mnist_test_10.csv",'r')
88 test_data_list = test_data_file.readlines()
89 test_data_file.close()
90 
91 all_values = test_data_list[0].split(',')
92 print(all_values[0])  #輸出目標值
93 
94 image_array = numpy.asfarray(all_values[1:]).reshape((28,28))
95 print(n.query((numpy.asfarray(all_values[1:])/255.0*0.99)+0.01))#輸出標簽值
96 plt.imshow(image_array,cmap='Greys',interpolation='None')#顯示原圖像
97 pylab.show()

輸出情況:

  從結果可以看出,我們輸入的目標值為7,結果中第7個標簽所對應的值最大,表明了正確識別了目標值。和圖片中的值一樣。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM