這是個分類應用入門:使用softmax分類,簡單來說就是概論轉化為0-1區間的一個數字
讀取數據集
1 # 導入相關庫 2 import tensorflow as tf 3 from tensorflow.examples.tutorials.mnist import input_data 4 mnist=input_data.read_data_sets("D:/MNIST",one_hot=True)
獨熱編碼(one hot encoding)
一種稀疏向量,其中:一個元素設為1,所有其他元素均設為0
獨熱編碼常用於表示擁有有限個可能值的字符串或標識符
工作流程:
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 import numpy as np 4 import tensorflow.examples.tutorials.mnist.input_data as input_data 5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True) #讀取數據 6 import os #可加可不加,屏蔽通知消息 7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 8 9 print('訓練集 train 數量:',mnist.train.num_examples, 10 ',驗證集 validation 數量:',mnist.validation.num_examples, 11 ',測試集 test 數量:',mnist.test.num_examples) 12 13 # print('train images shape:',mnist.train.images.shape, 14 # 'labels shaple:',mnist.train.labels.shape) 15 16 # print((mnist.train.images[0].reshape(28,28))) 17 # print(len(mnist.train.images[0].shape)) 18 19 # def plot_image(image): 20 # plt.imshow(image.reshape(28,28)) 21 # plt.show() 22 # 23 # plot_image(mnist.train.images[1]) 24 # plt.imshow(mnist.train.images[20000].reshape(14,56)) 25 # plt.show() 26 27 # print(mnist.train.labels[1]) 28 # print(np.argmax(mnist.train.labels[1])) 29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False) 30 # print(mnist_no_one_hot.train.labels[0:10]) 31 # 32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape) 33 # 34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape) 35 # 36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10) 37 # print(mnist.train.labels[0:10]) 38 # # print(batch_labels_ys) 39 40 # mnist中每張圖片共有28*28=784個像素點 41 x=tf.placeholder(tf.float32,[None,784]) 42 # 0-9一共10個數字->10個類別 43 y=tf.placeholder(tf.float32,[None,10]) 44 45 # 定義模型變量(以正態分布的隨機數初始化權重W,以常數0初始化偏置b) 46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0)) 47 b=tf.Variable(tf.zeros([10])) 48 49 # 前向計算 50 forward=tf.matmul(x,W)+b 51 #softmax分類 52 pred=tf.nn.softmax(forward) 53 # 定義交叉熵損失函數 54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 55 56 # 設置訓練參數 57 train_epochs=150 # 訓練輪數 58 batch_size=50 # 單次訓練樣本數(批次大小) 59 total_batch=int(mnist.train.num_examples/batch_size) # 一輪訓練的批次數 60 display_step=1 # 顯示粒度 61 learning_rate=0.04 # 學習率 62 63 # 分類模型構建與訓練實踐 64 #選擇優化器,梯度下降 65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) 66 67 # 定義准確率,檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y,1)的匹配情況 68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 69 # 准確率,將布爾值轉化為浮點數,並計算平均值 70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 71 72 # 聲明會話 73 sess=tf.Session() 74 init=tf.global_variables_initializer() 75 sess.run(init) 76 77 # 訓練模型 78 for epoch in range(train_epochs): 79 for batch in range(total_batch): 80 xs, ys = mnist.train.next_batch(batch_size) # 讀取批次數據 81 sess.run(optimizer, feed_dict={x: xs, y: ys}) # 執行批次訓練 82 83 # total_batch批次訓練完成之后,使用驗證數據計算誤差與准確率,驗證集沒有分批。 84 loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels}) 85 86 # 打印訓練過程中的詳細信息 87 if (epoch + 1)% display_step==0: 88 print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc)) 89 print("train finished!") 90 91 # 在測試集上評估模型准確率 92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) 93 print("test accuracy:",accu_test) 94 95 # 在驗證集上評估模型准確率 96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) 97 print("validatin accuracy:",accu_validation) 98 99 # 在訓練集上評估模型准確率 100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) 101 print("tarin accuracy:",accu_train) 102 103 # 由於pred預測結果是one-hot編碼格式,所以需要轉化為0~9數字 104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) 105 106 # 查看結果中的前十項 107 prediction_result[0:10] 108 # 定義可視化函數 109 110 def plt_images_labels_prediction(images, # 圖像列表 111 labels, # 標簽列表 112 prediction, # 預測值列表 113 index, # 從第index個開始顯示 114 num=10): # 缺省依次顯示10副 115 fig = plt.gcf() # 獲取當前圖表,get current figure 116 fig.set_size_inches(10, 12) # 1英寸等於2.54cm 117 if num > 25: 118 num = 25 # 最多顯示25個子圖 119 for i in range(0, num): 120 ax = plt.subplot(5, 5, i + 1) # 獲取當前要處理的子圖 121 122 ax.imshow(np.reshape(images[index], (28, 28)), 123 cmap='binary') # 顯示第index個圖像 124 title = "labels=" + str(np.argmax(labels[index])) # 構建該圖上要顯示的title信息 125 if len(prediction) > 0: 126 title += ",predict=" + str(prediction[index]) 127 128 ax.set_title(title) # 顯示圖上的title 129 ax.set_xticks([]) # 不顯示坐標軸 130 ax.set_yticks([]) 131 index += 1 132 plt.show() 133 # 可視化預測結果 134 plt_images_labels_prediction(mnist.test.images, 135 mnist.test.labels, 136 prediction_result,10,25)
大概就結束了,相當於機器學習的一個helloworld