MNIST深度學習識別：Python全連接神經網絡和PyTorch LeNet CNN網絡的訓練實現及比較（一）


版權聲明:本文為博主原創文章,歡迎轉載,並請注明出處。聯系方式:460356155@qq.com

全連接神經網絡是深度學習的基礎，理解它就可以掌握深度學習的核心概念：前向傳播、誤差反向傳播、權重、學習率等。這裡先用Python（numpy）創建模型，以MNIST作為數據集進行訓練。

定義3層神經網絡：輸入層節點28*28=784個（對應MNIST圖片的像素數）、隱藏層節點300個、輸出層節點10個（對應0~9共十個數字）。

網絡的激活函數採用sigmoid，網絡權重以均值為0的正態分布隨機初始化。

完整代碼如下:

  1 # -*- coding:utf-8 -*-
  2 
  3 u"""全連接神經網絡訓練學習MINIST"""
  4 
  5 __author__ = 'zhengbiqing 460356155@qq.com'
  6 
  7 
  8 import numpy
  9 import scipy.special
 10 import scipy.misc
 11 from PIL import Image
 12 import matplotlib.pyplot
 13 import pylab
 14 import datetime
 15 from random import shuffle
 16 
 17 
 18 #是否訓練網絡
 19 LEARN = True
 20 
 21 #是否保存網絡
 22 SAVE_PARA = False
 23 
 24 #網絡節點數
 25 INPUT = 784
 26 HIDDEN = 300
 27 OUTPUT = 10
 28 
 29 #學習率和訓練次數
 30 LR = 0.05
 31 EPOCH = 10
 32 
 33 #訓練數據集文件
 34 TRAIN_FILE = 'mnist_train.csv'
 35 TEST_FILE = 'mnist_test.csv'
 36 
 37 #網絡保存文件名
 38 WEIGHT_IH = "minist_fc_wih.npy"
 39 WEIGHT_HO = "minist_fc_who.npy"
 40 
 41 
 42 #神經網絡定義
 43 class NeuralNetwork:
 44     def __init__(self, inport_nodes, hidden_nodes, output_nodes, learnning_rate):
 45         #神經網絡輸入層、隱藏層、輸出層節點數
 46         self.inodes = inport_nodes
 47         self.hnodes = hidden_nodes
 48         self.onodes = output_nodes
 49 
 50         #神經網絡訓練學習率
 51         self.learnning_rate = learnning_rate
 52 
 53         #用均值為0,標准方差為連接數的-0.5次方的正態分布初始化權重
 54         #權重矩陣行列分別為hidden * input、 output * hidden,和ih、ho相反
 55         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
 56         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
 57 
 58         #sigmoid函數為激活函數
 59         self.active_fun = lambda x: scipy.special.expit(x)        
 60 
 61     #設置神經網絡權重,在加載已訓練的權重時調用
 62     def set_weight(self, wih, who):
 63         self.wih = wih
 64         self.who = who
 65 
 66     #前向傳播,根據輸入得到輸出
 67     def get_outputs(self, input_list):
 68         # 把list轉換為N * 1的矩陣,ndmin=2二維,T轉制
 69         inputs = numpy.array(input_list, ndmin=2).T
 70 
 71         # 隱藏層輸入 = W dot X,矩陣乘法
 72         hidden_inputs = numpy.dot(self.wih, inputs)
 73         hidden_outputs = self.active_fun(hidden_inputs)
 74 
 75         final_inputs = numpy.dot(self.who, hidden_outputs)
 76         final_outputs = self.active_fun(final_inputs)
 77 
 78         return inputs, hidden_outputs, final_outputs
 79 
 80     #網絡訓練,誤差計算,誤差反向分配更新網絡權重
 81     def train(self, input_list, target_list):
 82         inputs, hidden_outputs, final_outputs = self.get_outputs(input_list)
 83 
 84         targets = numpy.array(target_list, ndmin=2).T
 85 
 86         #誤差計算
 87         output_errors = targets - final_outputs
 88         hidden_errors = numpy.dot(self.who.T, output_errors)
 89 
 90         #連接權重更新
 91         self.who += numpy.dot(self.learnning_rate * output_errors * final_outputs * (1 - final_outputs), hidden_outputs.T)
 92         self.wih += numpy.dot(self.learnning_rate * hidden_errors * hidden_outputs * (1 - hidden_outputs), inputs.T)
 93         
 94 
 95 #圖像像素值變換
 96 def vals2input(vals):
 97     #[0,255]的圖像像素值轉換為i[0.01,1],以便sigmoid函數作非線性變換
 98     return (numpy.asfarray(vals) / 255.0 * 0.99) + 0.01
 99 
100 
def net_train(train, epochs, save):
    """Train the global network ``net``, or load previously saved weights.

    Args:
        train: if true, train for ``epochs`` passes over TRAIN_FILE.
            Otherwise, skip training and load WEIGHT_IH / WEIGHT_HO.
        epochs: number of passes over the training set.
        save: if true (and training), write the trained weights to
            WEIGHT_IH / WEIGHT_HO afterwards.
    """
    if not train:
        # No training: restore the weights saved by an earlier run.
        net.set_weight(numpy.load(WEIGHT_IH), numpy.load(WEIGHT_HO))
        return

    with open(TRAIN_FILE, 'r') as handle:
        records = handle.readlines()

    for epoch in range(epochs):
        # Visit the samples in a fresh random order every epoch.
        shuffle(records)

        for record in records:
            fields = record.split(',')
            # Pixels 0..255 are mapped into 0.01..1 for the sigmoid network.
            pixels = vals2input(fields[1:])

            # Soft one-hot target: 0.99 for the true digit, 0.01 for the rest.
            target = numpy.full(OUTPUT, 0.01)
            target[int(fields[0])] = 0.99

            net.train(pixels, target)

        # Check recognition accuracy on the test set after each epoch.
        net_test(epoch)
        print('')

    if save:
        # Persist both weight matrices.
        numpy.save(WEIGHT_IH, net.wih)
        numpy.save(WEIGHT_HO, net.who)
140 
141 
def net_test(epoch):
    """Evaluate the global network ``net`` on TEST_FILE and print a report.

    The report has two parts:

    * The accuracy: the percentage of test lines whose most strongly
      activated output node equals the label.
    * Per digit, how many test samples of that digit were misrecognised.

    Args:
        epoch: only used to label the printed report.
    """
    with open(TEST_FILE, 'r') as test_file:
        test_list = test_file.readlines()

    ok = 0
    # errlist[d] counts test samples whose true label d was misrecognised.
    errlist = [0] * OUTPUT

    for data in test_list:
        all_vals = data.split(',')
        label = int(all_vals[0])
        _, _, net_out = net.get_outputs(vals2input(all_vals[1:]))

        # Prediction = index of the highest output. Named ``predicted`` so the
        # builtin ``max`` is not shadowed.
        predicted = numpy.argmax(net_out)
        if predicted == label:
            ok += 1
        else:
            errlist[label] += 1

    # An empty test file reports 0 instead of raising ZeroDivisionError.
    score = ok / len(test_list) * 100 if test_list else 0.0
    print('EPOCH: {epoch} score: {score}'.format(epoch=epoch, score=score))
    print('error list: ', errlist, ' total: ', sum(errlist))
167 
168 
# Resize an image file and save the resized copy
def resize_img(filein, fileout, width, height, type):
    """Resize the image ``filein`` to ``width`` x ``height`` and save it.

    Args:
        filein: path of the source image.
        fileout: path the resized image is written to.
        width: target width in pixels.
        height: target height in pixels.
        type: format handed to Pillow's ``Image.save`` (e.g. 'png'). None
            lets Pillow infer the format from ``fileout``'s extension. The
            parameter name shadows the builtin but is kept for compatibility.
    """
    # ``with`` closes the file handle that Image.open keeps open.
    with Image.open(filein) as img:
        # Image.ANTIALIAS was removed in Pillow 10. It was an alias of
        # LANCZOS, so this is the same resampling filter.
        out = img.resize((width, height), Image.LANCZOS)
    out.save(fileout, type)
174 
175 
# Recognise an image file with the trained network
def img_test(img_file):
    """Recognise the digit in an image file with the global ``net`` and print it.

    Steps:

    1. Resize the image to 28x28 and save the result next to the input, with
       'out' appended to the base name.
    2. Load it as grey scale and invert it.
    3. Scale it into [0.01, 1.0] and run a forward pass.
    """
    import os.path

    # splitext only splits off the last extension, so names like
    # 'my.digit.png' or './t9.png' also work.
    file_name, file_ext = os.path.splitext(img_file)
    out_file = file_name + 'out' + file_ext
    # Format None: Pillow infers it from the extension. A bare extension such
    # as 'jpg' is not necessarily a Pillow format name ('JPEG').
    resize_img(img_file, out_file, 28, 28, None)

    # scipy.misc.imread was removed in SciPy 1.2. Its flatten=True option
    # converted to Pillow's 32-bit float grey-scale mode 'F', which is
    # reproduced here directly.
    with Image.open(out_file) as img:
        img_array = numpy.asarray(img.convert('F'))

    # Invert so that dark strokes become high values. Presumably this matches
    # the MNIST CSV convention (ink = high pixel value); confirm for your
    # input images.
    img_data = 255.0 - img_array.reshape(784)
    img_data = (img_data / 255.0 * 0.99) + 0.01

    _, _, net_out = net.get_outputs(img_data)
    predicted = numpy.argmax(net_out)
    print('pic recognized as: ', predicted)
190 
191 
# Show the image at a given index of a data set
def img_show(train, index):
    """Print the label of sample ``index`` and display its 28x28 pixels.

    Args:
        train: True to read TRAIN_FILE, False to read TEST_FILE.
        index: zero-based line index within that CSV file.
    """
    data_file = TRAIN_FILE if train else TEST_FILE
    with open(data_file, 'r') as handle:
        records = handle.readlines()

    all_values = records[index].split(',')
    print('number is: ', all_values[0])

    # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) replaces it.
    image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28, 28))
    matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
    pylab.show()
204 
205 
# ---- Script entry: build the network, train it (or load weights), report ----
start_time = datetime.datetime.now()

net = NeuralNetwork(INPUT, HIDDEN, OUTPUT, LR)
net_train(LEARN, EPOCH, SAVE_PARA)

if LEARN:
    # Training ran: summarise the configuration and the elapsed time.
    print('MINIST FC Train:', INPUT, HIDDEN, OUTPUT, 'LR:', LR, 'EPOCH:', EPOCH)
    print('train spend time: ', datetime.datetime.now() - start_time)
else:
    # Weights were loaded from disk: evaluate them once on the test set.
    net_test(0)

# Recognise an image drawn with a paint program using the trained network:
# img_test('t9.png')

# Display one of the MNIST images:
# img_show(True, 1)

784-300-10的簡單全連接神經網絡，訓練後的準確率基本在97.7%左右，運行結果如下：

EPOCH: 0 score: 95.96000000000001
error list:  [13, 21, 31, 28, 51, 61, 33, 66, 44, 56]  total:  404

EPOCH: 1 score: 96.77
error list:  [15, 19, 27, 63, 37, 37, 21, 40, 18, 46]  total:  323

EPOCH: 2 score: 97.25
error list:  [9, 17, 26, 26, 24, 56, 21, 41, 22, 33]  total:  275

EPOCH: 3 score: 97.82
error list:  [9, 16, 21, 18, 20, 18, 22, 21, 31, 42]  total:  218

EPOCH: 4 score: 97.54
error list:  [12, 23, 17, 25, 15, 34, 19, 25, 22, 54]  total:  246

EPOCH: 5 score: 97.78999999999999
error list:  [10, 16, 20, 23, 21, 32, 18, 31, 26, 24]  total:  221

EPOCH: 6 score: 97.6
error list:  [9, 13, 26, 34, 27, 26, 20, 28, 22, 35]  total:  240

EPOCH: 7 score: 97.74000000000001
error list:  [12, 8, 26, 29, 27, 26, 25, 20, 27, 26]  total:  226

EPOCH: 8 score: 97.77
error list:  [7, 10, 27, 16, 29, 28, 23, 29, 26, 28]  total:  223

EPOCH: 9 score: 97.99
error list:  [11, 10, 32, 17, 18, 24, 14, 22, 21, 32]  total:  201

MINIST FC Train: 784 300 10 LR: 0.05 EPOCH: 10
train spend time:  0:05:54.137925

Process finished with exit code 0

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM