MNIST數據集入門


簡單的訓練MNIST數據集 (0-9的數字圖片)

詳細地址(包括下載地址):http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

 

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import input_data  # 需要下載數據集(包括了input_data)
# 加載數據集 mnist
= input_data.read_data_sets("MNIST_data/", one_hot=True) # minist用來獲取批處理數據 # x: 任意數量的MNIST圖像,每一張圖展平成784維的向量。我們用2維的浮點數張量來表示這些 # 圖,這個張量的形狀是[None,784 ]。(這里的None表示此張量的第一個維度可以是任何 # 長度 batch取批量的大小 x圖片的數量。) x = tf.placeholder("float", shap=[None, 784]) # placeholdershape參數是可選的,但有了它,TensorFlow能夠自動捕捉因數據維度不一致導致的錯誤。 # 圖片設為“xs”,把這些標簽設為“ys” # softmax模型可以用來給不同的對象分配概率 W = tf.Variable(tf.zeros([784, 10])) # 28*28, 0-9 b = tf.Variable(tf.zeros([10])) # 0-9 # 構建模型 y = tf.nn.softmax(tf.matmul(x, W) + b) # y概率 # 訓練構建的模型 # 先定義指標評估模型好壞(指標稱為 成本cost,損失loss。小化這個指標) # 成本函數“交叉熵”cross-entropy。 # 計算交叉熵 需要添加新的占位符 y_: 實際分布one-hot [1,0,0,0,0,0,0,0,0,0] ?? y_ = tf.placeholder("float", [None, 10]) # 交叉熵 cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # tf的優化算法,根據交叉熵降低指標(成本,損失) # 梯度算法,0.01的學習率不斷地最小化交叉熵(指標) train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 運行模型前,初始化創建的變量 init = tf.initialize_all_variables() # 啟動init sess = tf.Session() sess.run(init) # 開始訓練模型1000次 for i in range(1000): # 獲得100個批處理數據點 batch_xs, batch_ys = mnist.train.next_batch(100) # 進行梯度算法 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # 評估模型tf.argmax(x, 1) # 給出某個tensor對象在某一維上的其數據最大值所在的索引值。由於標簽向量是由0,1組 # 成,因此最大值1所在的索引位置就是類別標簽,比如tf.argmax(y,1)返回的是模型對於 # 任一輸入x預測到的標簽值,而 tf.argmax(y_,1) 代表正確的標簽,我們可以用 # tf.equal 來檢測我們的預測是否真實標簽匹配(索引位置一樣表示匹配)。 current_prediction = tf.equal(tf.argmax(y, 1), tf.arg_max(y_, 1)) # 其結果為bool值 [True, False, ...] # 為了確定正確預測項的比例,我們可以把布爾值轉換成浮點數,然后取平均值 accuracy = tf.reduce_mean(tf.cast(current_prediction, "float")) # 運行accuracy print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) # 結果約為 91% 左右

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM