TensorFlow 入門之手寫識別(MNIST) softmax算法


TensorFlow 入門之手寫識別(MNIST) softmax算法

 

 

softmax回歸算法

我們知道MNIST的每一張圖片都表示一個數字,從0到9。我們希望得到給定圖片代表每個數字的概率。比如說,我們的模型可能推測一張包含9的圖片代表數字9的概率是80%但是判斷它是8的概率是5%(因為8和9都有上半部分的小圓),然后給予它代表其他數字的概率更小的值。
這是一個使用softmax回歸(softmax regression)模型的經典案例。 softmax 模型可以用來給不同的對象分配概率。即使在之后,我們訓練更加精細的模型時,最后一步也需要用softmax來分配概率。
這是一個使用softmax回歸(softmax regression)模型的經典案例。softmax模型可以用來給不同的對象分配概率。即使在之后,我們訓練更加精細的模型時,最后一步也需要用softmax來分配概率。

softmax回歸(softmax regression)分兩步:第一步

為了得到一張給定圖片屬於某個特定數字類的證據(evidence),我們對圖片像素值進行加權求和。如果這個像素具有很強的證據說明這張圖片不屬於該類,那么相應的權值為負數,相反如果這個像素擁有有利的證據支持這張圖片屬於這個類,那么權值是正數。

下面的圖片顯示了一個模型學習到的圖片上每個像素對於特定數字類的權值。紅色代表負數權值,藍色代表正數權值。

 


數字的特征

 

我們也需要加入一個額外的偏置量(bias),因為輸入往往會帶有一些無關的干擾量。因此對於給定的輸入圖片x它代表的是數字i的證據可以表示為

 


求和

 

其中Wi代表權重,bi 代表數字 i 類的偏置量,j 代表給定圖片 x 的像素索引用於像素求和。然后用softmax函數可以把這些證據轉換成概率 y:

 


激勵函數

 

這的softmax可是看做是一個sigmoid形式的函數。把我們定義的線性函數的輸出轉換成我們想要的格式,也就是關於10個數字類的概率分布。因此,給定一張圖片,它對於每一個數字的吻合度可以被softmax函數轉換成為一個概率值。

 


歸一化處理

 

展開等式右邊的子式,可以得到:

 


softmax使用的公式

 

對於softmax回歸模型可以用下面的圖解釋,對於輸入的xs加權求和,再分別加上一個偏置量,最后再輸入到softmax函數中:

 


softmax運行方式

 

如果把它寫成一個等式,我們可以得到:

 


softmax數學表達式

 

我們也可以用向量表示這個計算過程:用矩陣乘法和向量相加。這有助於提高計算效率。(也是一種更有效的思考方式):

 


softmax矩陣表現形式

 

更進一步,可以寫成更加緊湊的方式:

 


最終會使用的表達式

 

TensorFlow實現softmax

  1. # create a softmax regression 
  2.  
  3. import tensorflow as tf 
  4. from tensorflow.examples.tutorials.mnist import input_data 
  5.  
  6. mnist = input_data.read_data_sets("/home/fly/TensorFlow/mnist", one_hot=True
  7.  
  8. x = tf.placeholder(tf.float32,[None, 784]) 
  9.  
  10. W = tf.Variable(tf.zeros([784, 10])) 
  11.  
  12. b = tf.Variable(tf.zeros([10])) 
  13.  
  14. y = tf.nn.softmax(tf.matmul(x,W)+b) 
  15.  
  16. y_ = tf.placeholder(tf.float32,[None, 10]) 
  17.  
  18. cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
  19.  
  20. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
  21.  
  22. init = tf.initialize_all_variables() 
  23. sess = tf.Session() 
  24. sess.run(init) 
  25.  
  26. for i in range(1000): 
  27. batch_xs, batch_ys = mnist.train.next_batch(100
  28. sess.run(train_step, feed_dict = {x: batch_xs, y_: batch_ys}) 
  29.  
  30. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) 
  31. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
  32. print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_: mnist.test.labels})) 
  33.  

Fly
2016.6


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM