版權聲明:本文為博主原創文章,歡迎轉載,並請注明出處。聯系方式:460356155@qq.com
全連接神經網絡是深度學習的基礎,理解它就可以掌握深度學習的核心概念:前向傳播、反向誤差傳遞、權重、學習率等。這裡先用Python創建模型,用MNIST作為數據集進行訓練。
定義3層神經網絡:輸入層節點28*28(對應MNIST圖片像素數)、隱藏層節點300、輸出層節點10(對應0-9共10個數字)。
網絡的激活函數采用sigmoid,網絡權重的初始化采用正態分布。
完整代碼如下:
# -*- coding:utf-8 -*-

u"""Train a fully connected neural network on the MNIST digit data set.

The network has three layers (input, hidden, output) with sigmoid
activations and is trained one sample at a time by back-propagating the
output error.  Training and test data are read from CSV files in which the
first field of every line is the digit label and the remaining fields are
the pixel values.
"""

__author__ = 'zhengbiqing 460356155@qq.com'


import datetime
import os
from random import shuffle

import numpy
import scipy.special

# NOTE: PIL and matplotlib/pylab are imported inside the image helpers that
# use them (resize_img, img_test, img_show), so that training and evaluation
# only require numpy and scipy.


# Whether to train the network (otherwise saved weights are loaded).
LEARN = True

# Whether to save the trained weights.
SAVE_PARA = False

# Number of nodes per layer.
INPUT = 784
HIDDEN = 300
OUTPUT = 10

# Learning rate and number of training epochs.
LR = 0.05
EPOCH = 10

# Data set files.
TRAIN_FILE = 'mnist_train.csv'
TEST_FILE = 'mnist_test.csv'

# Files the weight matrices are saved to / loaded from.
WEIGHT_IH = "minist_fc_wih.npy"
WEIGHT_HO = "minist_fc_who.npy"


class NeuralNetwork:
    """Three-layer fully connected network with sigmoid activations.

    Attributes:
        inodes, hnodes, onodes: node counts of the input, hidden and
            output layers.
        learnning_rate: step size applied to every weight update.
        wih: weight matrix of shape (hnodes, inodes), input -> hidden.
        who: weight matrix of shape (onodes, hnodes), hidden -> output.
        active_fun: element-wise activation function (the logistic sigmoid).
    """

    def __init__(self, inport_nodes, hidden_nodes, output_nodes, learnning_rate):
        """Create a network with randomly initialised weights.

        Args:
            inport_nodes: number of input nodes.
            hidden_nodes: number of hidden nodes.
            output_nodes: number of output nodes.
            learnning_rate: learning rate used by :meth:`train`.
        """
        self.inodes = inport_nodes
        self.hnodes = hidden_nodes
        self.onodes = output_nodes

        self.learnning_rate = learnning_rate

        # Weights are drawn from a zero-mean normal distribution whose
        # standard deviation is (node count of the receiving layer) ** -0.5.
        # The matrices are laid out as hidden x input and output x hidden,
        # i.e. the reverse of the "ih" / "ho" order in their names.
        # NOTE(review): scaling by the number of *incoming* links (inodes /
        # hnodes) is the more common choice; left unchanged here because it
        # would alter training results - confirm the intent.
        self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

        # The logistic sigmoid is the activation function.
        self.active_fun = scipy.special.expit

    def set_weight(self, wih, who):
        """Replace both weight matrices, e.g. with previously saved ones.

        Args:
            wih: array of shape (hnodes, inodes).
            who: array of shape (onodes, hnodes).

        Raises:
            ValueError: if either array does not match the network layout.
        """
        wih = numpy.asarray(wih)
        who = numpy.asarray(who)
        if wih.shape != (self.hnodes, self.inodes):
            raise ValueError('wih has shape {0}, expected {1}'.format(
                wih.shape, (self.hnodes, self.inodes)))
        if who.shape != (self.onodes, self.hnodes):
            raise ValueError('who has shape {0}, expected {1}'.format(
                who.shape, (self.onodes, self.hnodes)))
        self.wih = wih
        self.who = who

    def get_outputs(self, input_list):
        """Run a forward pass.

        Args:
            input_list: sequence of ``inodes`` input values.

        Returns:
            Tuple ``(inputs, hidden_outputs, final_outputs)``; each element
            is a column vector (2-D array with a single column).
        """
        # Turn the flat list into an N x 1 column vector.
        inputs = numpy.array(input_list, ndmin=2).T

        # Hidden layer: activation(W . x)
        hidden_inputs = numpy.dot(self.wih, inputs)
        hidden_outputs = self.active_fun(hidden_inputs)

        # Output layer: activation(W . hidden)
        final_inputs = numpy.dot(self.who, hidden_outputs)
        final_outputs = self.active_fun(final_inputs)

        return inputs, hidden_outputs, final_outputs

    def train(self, input_list, target_list):
        """Perform one training step on a single sample.

        Args:
            input_list: sequence of ``inodes`` input values.
            target_list: sequence of ``onodes`` desired output values.
        """
        inputs, hidden_outputs, final_outputs = self.get_outputs(input_list)

        targets = numpy.array(target_list, ndmin=2).T

        # Output error, and that error distributed back to the hidden layer.
        # hidden_errors must be computed before self.who is updated below.
        output_errors = targets - final_outputs
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # Weight updates; out * (1 - out) is the derivative of the sigmoid.
        self.who += numpy.dot(
            self.learnning_rate * output_errors * final_outputs * (1 - final_outputs),
            hidden_outputs.T)
        self.wih += numpy.dot(
            self.learnning_rate * hidden_errors * hidden_outputs * (1 - hidden_outputs),
            inputs.T)


def vals2input(vals):
    """Scale pixel values from [0, 255] to [0.01, 1.0].

    Args:
        vals: sequence of numbers or numeric strings.

    Returns:
        A float numpy array with the rescaled values.
    """
    # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the
    # equivalent call.
    return (numpy.asarray(vals, dtype=float) / 255.0 * 0.99) + 0.01


def _read_records(path):
    """Return the non-blank lines of the CSV file at *path*."""
    with open(path, 'r') as csv_file:
        return [line for line in csv_file if line.strip()]


def _parse_record(line):
    """Split one CSV line into ``(label, network_inputs)``."""
    all_vals = line.split(',')
    return int(all_vals[0]), vals2input(all_vals[1:])


def net_train(train, epochs, save):
    """Train the global network ``net`` or load previously saved weights.

    Args:
        train: if true, train on TRAIN_FILE; otherwise load the weights
            stored in WEIGHT_IH / WEIGHT_HO.
        epochs: number of passes over the training data.
        save: if true, store the weights after training.
    """
    if not train:
        # Skip training and load the saved weights instead.
        net.set_weight(numpy.load(WEIGHT_IH), numpy.load(WEIGHT_HO))
        return

    train_list = _read_records(TRAIN_FILE)

    for epoch in range(epochs):
        # Present the samples in a different order every epoch.
        shuffle(train_list)

        for data in train_list:
            label, inputs = _parse_record(data)

            # Target vector: 0.99 for the correct digit, 0.01 elsewhere.
            targets = numpy.zeros(OUTPUT) + 0.01
            targets[label] = 0.99

            net.train(inputs, targets)

        # Check accuracy on the test set after every epoch.
        net_test(epoch)
        print('')

    if save:
        numpy.save(WEIGHT_IH, net.wih)
        numpy.save(WEIGHT_HO, net.who)


def net_test(epoch):
    """Evaluate the global network ``net`` on TEST_FILE and print the result.

    Args:
        epoch: epoch number, used only in the printed report.
    """
    test_list = _read_records(TEST_FILE)

    ok = 0
    # Misclassification count per true digit.
    errlist = [0] * OUTPUT

    for data in test_list:
        label, inputs = _parse_record(data)
        _, _, net_out = net.get_outputs(inputs)

        predicted = numpy.argmax(net_out)
        if predicted == label:
            ok += 1
        else:
            errlist[label] += 1

    # Avoid a ZeroDivisionError when the test file holds no records.
    score = ok / len(test_list) * 100 if test_list else 0.0
    print('EPOCH: {epoch} score: {score}'.format(epoch=epoch, score=score))
    print('error list: ', errlist, ' total: ', sum(errlist))


def resize_img(filein, fileout, width, height, type):
    """Resize an image file and save the result.

    Args:
        filein: path of the source image.
        fileout: path the resized image is written to.
        width: target width in pixels.
        height: target height in pixels.
        type: image format name handed to ``Image.save``; ``None`` lets
            Pillow infer the format from the extension of *fileout*.
            (The name shadows the builtin but is kept for compatibility.)
    """
    from PIL import Image

    with Image.open(filein) as img:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        out = img.resize((width, height), Image.LANCZOS)
        out.save(fileout, type)


def img_test(img_file):
    """Classify a single image file with the global network ``net``.

    A 28 x 28 copy named ``<name>out<ext>`` is written next to the original
    image, inverted, scaled like the training data, and the recognised digit
    is printed.

    Args:
        img_file: path of the image to recognise.
    """
    from PIL import Image

    # splitext copes with paths that contain more than one dot.
    root, ext = os.path.splitext(img_file)
    out_file = root + 'out' + (ext or '.png')
    # Let Pillow derive the format from the extension ('jpg' is not a valid
    # Pillow format name, so the raw extension cannot be passed through).
    resize_img(img_file, out_file, 28, 28, None)

    # scipy.misc.imread(..., flatten=True) was removed from SciPy; converting
    # to mode 'F' yields the same 32-bit float grey-scale image.
    with Image.open(out_file) as img:
        img_array = numpy.asarray(img.convert('F'), dtype=float)

    # Invert the grey levels, then scale to [0.01, 1.0].
    img_data = 255.0 - img_array.reshape(INPUT)
    img_data = (img_data / 255.0 * 0.99) + 0.01

    _, _, net_out = net.get_outputs(img_data)
    digit = numpy.argmax(net_out)
    print('pic recognized as: ', digit)


def img_show(train, index):
    """Display one image of the data set.

    Args:
        train: if true read from TRAIN_FILE, otherwise from TEST_FILE.
        index: zero-based line index of the record to show.
    """
    import matplotlib.pyplot
    import pylab

    path = TRAIN_FILE if train else TEST_FILE
    with open(path, 'r') as data_file:
        data_list = data_file.readlines()

    all_values = data_list[index].split(',')
    print('number is: ', all_values[0])

    image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28, 28))
    matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
    pylab.show()


def main():
    """Train (or load) the network and report accuracy and elapsed time."""
    start_time = datetime.datetime.now()

    net_train(LEARN, EPOCH, SAVE_PARA)

    if not LEARN:
        net_test(0)
    else:
        print('MINIST FC Train:', INPUT, HIDDEN, OUTPUT, 'LR:', LR, 'EPOCH:', EPOCH)
        print('train spend time: ', datetime.datetime.now() - start_time)

    # To recognise a picture drawn in a paint program:
    #     img_test('t9.png')
    # To display one image of the MNIST training set:
    #     img_show(True, 1)


# The network shared by net_train, net_test and img_test.
net = NeuralNetwork(INPUT, HIDDEN, OUTPUT, LR)

if __name__ == '__main__':
    main()
784-300-10簡單的全連接神經網絡訓練結果準確率基本在97.7%左右,運行結果如下:
EPOCH: 0 score: 95.96000000000001
error list: [13, 21, 31, 28, 51, 61, 33, 66, 44, 56] total: 404
EPOCH: 1 score: 96.77
error list: [15, 19, 27, 63, 37, 37, 21, 40, 18, 46] total: 323
EPOCH: 2 score: 97.25
error list: [9, 17, 26, 26, 24, 56, 21, 41, 22, 33] total: 275
EPOCH: 3 score: 97.82
error list: [9, 16, 21, 18, 20, 18, 22, 21, 31, 42] total: 218
EPOCH: 4 score: 97.54
error list: [12, 23, 17, 25, 15, 34, 19, 25, 22, 54] total: 246
EPOCH: 5 score: 97.78999999999999
error list: [10, 16, 20, 23, 21, 32, 18, 31, 26, 24] total: 221
EPOCH: 6 score: 97.6
error list: [9, 13, 26, 34, 27, 26, 20, 28, 22, 35] total: 240
EPOCH: 7 score: 97.74000000000001
error list: [12, 8, 26, 29, 27, 26, 25, 20, 27, 26] total: 226
EPOCH: 8 score: 97.77
error list: [7, 10, 27, 16, 29, 28, 23, 29, 26, 28] total: 223
EPOCH: 9 score: 97.99
error list: [11, 10, 32, 17, 18, 24, 14, 22, 21, 32] total: 201
MINIST FC Train: 784 300 10 LR: 0.05 EPOCH: 10
train spend time: 0:05:54.137925
Process finished with exit code 0