一起學TensorFlow---搭建最簡單的全連接網絡實現手寫數字識別(MNIST)


剛開始學Tensorflow,這里記錄學習中的點點滴滴,希望能和大家共同進步。

Cuda和Tensorflow的安裝請參考上一篇博客:http://www.cnblogs.com/roboai/p/7768191.html

Tensorflow簡單介紹

  我們知道,一維的數據可以用數組表示,二維可以用矩陣表示,那么三維或三維以上呢?比如圖像,實際上就是一個三維數據[h,w,c],高、寬、通道數,對於灰度圖來說,通道數為1,而對於彩色圖像,通道數為3。對於這種三維或三維以上的數據,我們稱之為張量(tensor),所以顧名思義,Tensorflow的意思就是張量的流動,Tensorflow將數據打包成一個個張量,由四個維度構成,分別是[batch, height, width, channels],然后在各個節點之間傳遞。

  節點是Tensorflow里另一重要的概念,對張量的操作稱之為節點,一系列的節點構成圖。接觸過Caffe的朋友可能發現了,這和Caffe里的blob、layer、net是一致的。不同的是,我們需要啟動一個會話來計算圖,這是Tensorflow的內在機制所決定的。Tensorflow依賴於一個高效的C++后端來進行計算,與后端的這個連接叫做session。一般而言,使用TensorFlow程序的流程是先創建一個圖,然后在session中啟動它。其思想是先讓我們描述一個交互操作圖,然后完全將其運行在Python外部。這樣做的目的是為了避免頻繁切換Python環境和外部環境時需要的開銷。如果你想在GPU或者分布式環境中計算時,這一開銷會非常可怖,這一開銷主要可能是用來進行數據遷移,並不能對計算做出貢獻。

  我們構建一個簡單的圖來說明以上過程,改圖包含三個節點(兩個源節點和一個矩陣乘法節點),然后啟動一個會話計算圖得到輸出結果,最后需要關閉會話。當然也可以使用with代碼塊實現自動關閉,效果是一樣的。

# coding=utf-8
import tensorflow as tf

# 該圖包含3個節點(兩個源節點和乘法節點)
matrix1 = tf.constant([[3, 3]])
matrix2 = tf.constant([[2], [2]])
product = tf.matmul(matrix1, matrix2)

# 調用會話啟動圖
sess = tf.Session()
result = sess.run(product)

# 輸出結果並關閉會話
print result
sess.close()

# 使用“with”代碼塊自動關閉, 該方法更簡潔
with tf.Session() as sess:
    result = sess.run(product)
    print result

輸出結果為

[[12]]
[[12]]

MNIST數據集

  MNIST是一個入門級的計算機視覺數據集,它包含各種手寫數字圖片,也包含每一張圖片對應的標簽,告訴我們這個是數字幾。新建一個get.sh文件,寫入以下內容,執行該文件就可以下載該數據集。下載下來的數據集被分成兩部分,60000行的訓練數據集和10000行的測試數據集。每一張圖片包含28X28個像素點,我們可以把圖片展開成一個向量,長度是 28x28 = 784。

#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.

DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"

echo "Downloading..."

for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
do
    if [ ! -e $fname ]; then
        wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz
    fi
done

Softmax Regression與Cross Entropy

  在本文中,我們將采用最簡單的網絡來預測輸入圖片中的數字,整個網絡僅由一個Softmax Regression構成,數學模型可以寫作\(y=softmax(Wx+b)\)。假設\(y'\)是實際分布,\(y\)是預測分布,Cross Entropy的定義是\(loss=\sum{y'\log{y}}\)。關於Softmax Regression的反向傳遞及Cross Entropy的物理含義請參考以下兩篇博客,這里就不展開寫了。

  http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92  

  http://blog.csdn.net/rtygbwwwerr/article/details/50778098

全連接網絡實現手寫數字識別

  下面終於進入正題了,我們有了數據集,同時也了解了算法流程,剩下的就是寫代碼實現了。首先是導入包,由於Tensorflow幫我們寫了一部分數據讀寫的程序,我們這里就直接用了。

# coding=utf-8
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf

# 導入數據, 強烈建議預先下載
mnist = input_data.read_data_sets("data/", one_hot=True)

  這里數據可以用我前面給出的get.sh下載,然后放入data文件夾目錄下,我之前是直接用input_data.read_data_sets("data/", one_hot=True)下載的,結果半天下載不下來,所以這里還是建議預先下載吧,用get.sh下載比較快。然后是程序的主要部分。

# 訓練集占位符:28*28=784
x = tf.placeholder(tf.float32, [None, 784])
# 初始化參數
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 輸出結果
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 真實值
y_ = tf.placeholder(tf.float32, [None, 10])
# 計算交叉熵
crossEntropy = -tf.reduce_sum(y_*tf.log(y))
# 訓練策略
trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)
# 初始化參數值
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 開始訓練:循環訓練1000次
for i in range(1000):
    batchXs, batchYs = mnist.train.next_batch(100)
    sess.run(trainStep, feed_dict={x: batchXs, y_: batchYs})

# 評估模型
correctPrediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

  這里用的是占位符的方式傳入數據,占位符的尺寸為[None, 784],這里的None表示此張量的第一個維度可以是任何長度的。

  權重值W和偏置量b使用Variable來表示,一個Variable代表一個可修改的張量,存在在Tensorflow的用於描述交互性操作的圖中。它們可以用於計算輸入值,也可以在計算中被修改。對於各種機器學習應用,一般都會有模型參數,都可以用Variable表示。在這里,我們都用全為零的張量來初始化Wb。

  只需要一行代碼就可以實現我們的模型y = tf.nn.softmax(tf.matmul(x, W) + b),同樣損失函數也只需要一行代碼crossEntropy = -tf.reduce_sum(y_*tf.log(y))。

  以0.01的學習速率,采用梯度下降法最小化交叉熵,對應的代碼為trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)。

  然后初始化參數並訓練,定義訓練次數為1000,每次隨機地選取100圖像進行計算。

  最后對得到的模型使用測試數據進行評估,評估結果表明精度達到0.9148(每次都不一樣,在91%左右徘徊)。

  至此,我們采用最簡單的一個全連接網絡實現了一個手寫數字識別的網絡,剩下的工作是將這個網絡及參數保存,采用自己的圖片進行識別,進一步感受這個網絡的效果,這一部分將在后續的工作中進行。同時我們可以說這個網絡過於簡單了,91%的識別效果也遠遠達不到我們的需求,如何進一步提高網絡的精度是我們關注的重點。

 


關於會話

  會話(session)提供在圖中執行操作的一些方法。一般的模式是:

  1.  建立會話,此時會生成一張空圖;
  2.  在會話中添加節點和邊,形成一張圖;
  3.  執行圖

  在調用Session對象的run()方法來執行圖時,傳入一些Tensor,這個過程叫填充(feed);返回的結果類型根據輸入的類型而定,這個過程叫取回(fetch)。

  會話是圖交互的橋梁,一個會話可以有多個圖,會話可以修改圖的結構,也可以往圖中注入數據進行計算。因此,會話主要由兩個API接口--Extend和Run。Extend操作是在Graph中添加節點和邊,Run操作是輸入計算的節點和填充必要的數據后,進行計算,並輸出運算結果。

關於節點與圖

  圖中的節點又稱為算子,它代表一個操作(Operation,op),一般用來表示施加的數學運算,也可以表示數據輸入(feed in)的起點以及輸出(push out)的終點,或者是讀取/寫入持久變量(persistent variable)的終點。

  如果不顯式添加一個默認圖,系統會自動設置一個全局的默認圖。所設置的默認圖,在模塊范圍內定義的節點都將默認加入默認圖中。

關於可視化

  可視化時,需要在程序中給必要的節點添加摘要(summary),摘要會收集該節點的數據,並標記上第幾步、時間戳等標識,寫入事件文件(event file)中。

模型存儲與加載

  TensorFLow的API提供了兩種方式存儲和加載模型:

  (1)生成檢查點文件,拓展名一般為.ckpt,通過tf.train.Saver.save()生成。它包含權重和程序中定義的變量,不包含圖結構。如果需要在另一個程序中使用,需要重新構建圖結構,並告訴TensorFlow如何處理這些權重。

  (2)生成圖協議文件,這是一個二進制文件,拓展名一般為.pb,用tf.train.write_graph()保存,只包含圖形結構,不包含權重,然后使用tf.import_graph_def()來加載圖形。

模型訓練之Momentum

  Momentum是模擬物理學中的動量的概念,更新時在一定程度上保留之前的更新方向,利用當前的批次再微調本次的更新參數,因此引入了一個新的變量v(速度),作為前幾次梯度的累加。因此,Momentum能夠改善訓練過程,在下降初期,前后梯度一致時,能夠加速學習;在下降的中后期,在局部最小值附近來回震盪時,能夠抑制震盪,加快收斂。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM