剛開始學Tensorflow,這里記錄學習中的點點滴滴,希望能和大家共同進步。
Cuda和Tensorflow的安裝請參考上一篇博客:http://www.cnblogs.com/roboai/p/7768191.html
Tensorflow簡單介紹
我們知道,一維的數據可以用數組表示,二維可以用矩陣表示,那么三維或三維以上呢?比如圖像,實際上就是一個三維數據[h,w,c],高、寬、通道數,對於灰度圖來說,通道數為1,而對於彩色圖像,通道數為3。對於這種三維或三維以上的數據,我們稱之為張量(tensor),所以顧名思義,Tensorflow的意思就是張量的流動,Tensorflow將數據打包成一個個張量,由四個維度構成,分別是[batch, height, width, channels]
,然后在各個節點之間傳遞。
節點是Tensorflow里另一重要的概念,對張量的操作稱之為節點,一系列的節點構成圖。接觸過Caffe的朋友可能發現了,這和Caffe里的blob、layer、net是一致的。不同的是,我們需要啟動一個會話來計算圖,這是Tensorflow的內在機制所決定的。Tensorflow依賴於一個高效的C++后端來進行計算,與后端的這個連接叫做session。一般而言,使用TensorFlow程序的流程是先創建一個圖,然后在session中啟動它。其思想是先讓我們描述一個交互操作圖,然后完全將其運行在Python外部。這樣做的目的是為了避免頻繁切換Python環境和外部環境時需要的開銷。如果你想在GPU或者分布式環境中計算時,這一開銷會非常可怖,這一開銷主要可能是用來進行數據遷移,並不能對計算做出貢獻。
我們構建一個簡單的圖來說明以上過程,改圖包含三個節點(兩個源節點和一個矩陣乘法節點),然后啟動一個會話計算圖得到輸出結果,最后需要關閉會話。當然也可以使用with代碼塊實現自動關閉,效果是一樣的。
# coding=utf-8 import tensorflow as tf # 該圖包含3個節點(兩個源節點和乘法節點) matrix1 = tf.constant([[3, 3]]) matrix2 = tf.constant([[2], [2]]) product = tf.matmul(matrix1, matrix2) # 調用會話啟動圖 sess = tf.Session() result = sess.run(product) # 輸出結果並關閉會話 print result sess.close() # 使用“with”代碼塊自動關閉, 該方法更簡潔 with tf.Session() as sess: result = sess.run(product) print result
輸出結果為
[[12]]
[[12]]
MNIST數據集
MNIST是一個入門級的計算機視覺數據集,它包含各種手寫數字圖片,也包含每一張圖片對應的標簽,告訴我們這個是數字幾。新建一個get.sh文件,寫入以下內容,執行該文件就可以下載該數據集。下載下來的數據集被分成兩部分,60000行的訓練數據集和10000行的測試數據集。每一張圖片包含28X28個像素點,我們可以把圖片展開成一個向量,長度是 28x28 = 784。
#!/usr/bin/env sh # This scripts downloads the mnist data and unzips it. DIR="$( cd "$(dirname "$0")" ; pwd -P )" cd "$DIR" echo "Downloading..." for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte do if [ ! -e $fname ]; then wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz fi done
Softmax Regression與Cross Entropy
在本文中,我們將采用最簡單的網絡來預測輸入圖片中的數字,整個網絡僅由一個Softmax Regression構成,數學模型可以寫作\(y=softmax(Wx+b)\)。假設\(y'\)是實際分布,\(y\)是預測分布,Cross Entropy的定義是\(loss=\sum{y'\log{y}}\)。關於Softmax Regression的反向傳遞及Cross Entropy的物理含義請參考以下兩篇博客,這里就不展開寫了。
http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
http://blog.csdn.net/rtygbwwwerr/article/details/50778098
全連接網絡實現手寫數字識別
下面終於進入正題了,我們有了數據集,同時也了解了算法流程,剩下的就是寫代碼實現了。首先是導入包,由於Tensorflow幫我們寫了一部分數據讀寫的程序,我們這里就直接用了。
# coding=utf-8 import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf # 導入數據, 強烈建議預先下載 mnist = input_data.read_data_sets("data/", one_hot=True)
這里數據可以用我前面給出的get.sh下載,然后放入data文件夾目錄下,我之前是直接用input_data.read_data_sets("data/", one_hot=True)下載的,結果半天下載不下來,所以這里還是建議預先下載吧,用get.sh下載比較快。然后是程序的主要部分。
# 訓練集占位符:28*28=784 x = tf.placeholder(tf.float32, [None, 784]) # 初始化參數 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) # 輸出結果 y = tf.nn.softmax(tf.matmul(x, W) + b) # 真實值 y_ = tf.placeholder(tf.float32, [None, 10]) # 計算交叉熵 crossEntropy = -tf.reduce_sum(y_*tf.log(y)) # 訓練策略 trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy) # 初始化參數值 init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) # 開始訓練:循環訓練1000次 for i in range(1000): batchXs, batchYs = mnist.train.next_batch(100) sess.run(trainStep, feed_dict={x: batchXs, y_: batchYs}) # 評估模型 correctPrediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)) print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
這里用的是占位符的方式傳入數據,占位符的尺寸為[None, 784],這里的None
表示此張量的第一個維度可以是任何長度的。
權重值W和偏置量b使用Variable來表示,
一個Variable
代表一個可修改的張量,存在在Tensorflow的用於描述交互性操作的圖中。它們可以用於計算輸入值,也可以在計算中被修改。對於各種機器學習應用,一般都會有模型參數,都可以用Variable
表示。在這里,我們都用全為零的張量來初始化W
和b。
只需要一行代碼就可以實現我們的模型y = tf.nn.softmax(tf.matmul(x, W) + b),同樣損失函數也只需要一行代碼crossEntropy = -tf.reduce_sum(y_*tf.log(y))。
以0.01的學習速率,采用梯度下降法最小化交叉熵,對應的代碼為trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)。
然后初始化參數並訓練,定義訓練次數為1000,每次隨機地選取100圖像進行計算。
最后對得到的模型使用測試數據進行評估,評估結果表明精度達到0.9148(每次都不一樣,在91%左右徘徊)。
至此,我們采用最簡單的一個全連接網絡實現了一個手寫數字識別的網絡,剩下的工作是將這個網絡及參數保存,采用自己的圖片進行識別,進一步感受這個網絡的效果,這一部分將在后續的工作中進行。同時我們可以說這個網絡過於簡單了,91%的識別效果也遠遠達不到我們的需求,如何進一步提高網絡的精度是我們關注的重點。
關於會話
會話(session)提供在圖中執行操作的一些方法。一般的模式是:
- 建立會話,此時會生成一張空圖;
- 在會話中添加節點和邊,形成一張圖;
- 執行圖
在調用Session對象的run()方法來執行圖時,傳入一些Tensor,這個過程叫填充(feed);返回的結果類型根據輸入的類型而定,這個過程叫取回(fetch)。
會話是圖交互的橋梁,一個會話可以有多個圖,會話可以修改圖的結構,也可以往圖中注入數據進行計算。因此,會話主要由兩個API接口--Extend和Run。Extend操作是在Graph中添加節點和邊,Run操作是輸入計算的節點和填充必要的數據后,進行計算,並輸出運算結果。
關於節點與圖
圖中的節點又稱為算子,它代表一個操作(Operation,op),一般用來表示施加的數學運算,也可以表示數據輸入(feed in)的起點以及輸出(push out)的終點,或者是讀取/寫入持久變量(persistent variable)的終點。
如果不顯式添加一個默認圖,系統會自動設置一個全局的默認圖。所設置的默認圖,在模塊范圍內定義的節點都將默認加入默認圖中。
關於可視化
可視化時,需要在程序中給必要的節點添加摘要(summary),摘要會收集該節點的數據,並標記上第幾步、時間戳等標識,寫入事件文件(event file)中。
模型存儲與加載
TensorFLow的API提供了兩種方式存儲和加載模型:
(1)生成檢查點文件,拓展名一般為.ckpt,通過tf.train.Saver.save()生成。它包含權重和程序中定義的變量,不包含圖結構。如果需要在另一個程序中使用,需要重新構建圖結構,並告訴TensorFlow如何處理這些權重。
(2)生成圖協議文件,這是一個二進制文件,拓展名一般為.pb,用tf.train.write_graph()保存,只包含圖形結構,不包含權重,然后使用tf.import_graph_def()來加載圖形。
模型訓練之Momentum
Momentum是模擬物理學中的動量的概念,更新時在一定程度上保留之前的更新方向,利用當前的批次再微調本次的更新參數,因此引入了一個新的變量v(速度),作為前幾次梯度的累加。因此,Momentum能夠改善訓練過程,在下降初期,前后梯度一致時,能夠加速學習;在下降的中后期,在局部最小值附近來回震盪時,能夠抑制震盪,加快收斂。