MNIST手寫識別


  Demo俠可能是我等小白進階的必經之路了,如今在AI領域,我也是個研究Demo的小白。用了兩三天裝好環境,跑通Demo,自學Python語法,進而研究這個Demo。當然過程中查了很多資料,充分發揮了小白的主觀能動性,總算有一些收獲需要總結下。

  不多說,算法在代碼中,一切也都在代碼中。

 1 import os  2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 3 
 4 #獲得數據集
 5 from tensorflow.examples.tutorials.mnist import input_data  6 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  7 
 8 import tensorflow as tf  9 
10 #輸入圖像數據占位符
11 x = tf.placeholder(tf.float32, [None, 784]) 12 
13 #權值和偏差
14 W = tf.Variable(tf.zeros([784, 10])) 15 b = tf.Variable(tf.zeros([10])) 16 
17 #使用softmax模型
18 y = tf.nn.softmax(tf.matmul(x, W) + b) 19 
20 #代價函數占位符
21 y_ = tf.placeholder(tf.float32, [None, 10]) 22 
23 #交叉熵評估代價
24 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 25 
26 #使用梯度下降算法優化:學習速率為0.5
27 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 28 
29 #Session(交互方式)
30 sess = tf.InteractiveSession() 31 
32 #初始化變量
33 tf.global_variables_initializer().run() 34 
35 #訓練模型,訓練1000次
36 for _ in range(1000): 37   batch_xs, batch_ys = mnist.train.next_batch(100) 38   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 39 
40 #計算正確率
41 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 42 
43 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 44 print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

  看完這個Demo,頓時感覺Python真是一門好語言,Tensorflow是一個好框架,就跟之前掌握Matlab以后,用Matlab做仿真的感覺一樣。

  為什么看這幾行代碼看了兩三天,因為看懂很容易,但了解代碼背后的意義更重要,如果把一個Demo看透了,那么后邊舉一反三就會很容易了,我向來就是這樣學習的,本小白當年也是個學霸?!

  來一起看下這里邊有什么玄機和坑吧,記錄一下,人老了記性不好(^-^)。

  看到1,2行代碼,不要懵,這個作用是設置日志級別,os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只顯示 warning 和 Error,等於1是顯示所有信息。不加這兩行會有個提示(Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2,具體可以看這里 

  第5行是一個引用聲明,從tensorflow.examples.tutorials.mnist 引用一個名為 input_data 的函數,可以看一下input_data是什么樣子的:

 1 from __future__ import absolute_import  2 from __future__ import division  3 from __future__ import print_function  4 
 5 import gzip  6 import os  7 import tempfile  8 
 9 import numpy 10 from six.moves import urllib 11 from six.moves import xrange  # pylint: disable=redefined-builtin
12 import tensorflow as tf 13 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

  原來input_data里邊也是引用聲明,真正想用到的實際是tensorflow.contrib.learn.python.learn.datasets.mnist里的read_data_sets,看一下代碼:

 1 def read_data_sets(train_dir,  2                    fake_data=False,  3                    one_hot=False,  4                    dtype=dtypes.float32,  5                    reshape=True,  6                    validation_size=5000,  7                    seed=None,  8                    source_url=DEFAULT_SOURCE_URL):  9   if fake_data: 10  ... 11 
12   if not source_url:  # empty string check
13  ... 14 
15   local_file = base.maybe_download(TRAIN_IMAGES, train_dir, 16                                    source_url + TRAIN_IMAGES) 17   with gfile.Open(local_file, 'rb') as f: 18     train_images = extract_images(f) 19 
20  ... 21 
22   if not 0 <= validation_size <= len(train_images): 23     raise ValueError('Validation size should be between 0 and {}. Received: {}.'
24  .format(len(train_images), validation_size)) 25 
26   validation_images = train_images[:validation_size] 27   validation_labels = train_labels[:validation_size] 28   train_images = train_images[validation_size:] 29   train_labels = train_labels[validation_size:] 30 
31   options = dict(dtype=dtype, reshape=reshape, seed=seed) 32 
33   train = DataSet(train_images, train_labels, **options) 34   validation = DataSet(validation_images, validation_labels, **options) 35   test = DataSet(test_images, test_labels, **options) 36 
37   return base.Datasets(train=train, validation=validation, test=test)

  mnist最終得到的是base.Datasets,完成了數據讀取。這里邊的細節還需要完了再仔細研究下。

  順便記錄下自編的函數的定義方法:

 1 def Mycollect(My , thing):
 2 
 3     try:
 4         count = My[thing]
 5     except KeyError:
 6         count = 0
 7 
 8     return count
 9 
10 from TestFunction import Mycollect
11 My = {'a':10, 'b':15, 'c':5}
12 thing = 'a'
13 print(Mycollect(My , thing));

 

  第11行的placeholder,需要注意下,是用了占位符,也就是先安排位置,而不先提供具體數據,也就是說都是模型(管道)的構建過程(這里用管道來類比,我覺得比較恰當)。注意下placeholder的語法就可以,指定了type和shape,這里的None表示有多少幅圖片是未知的,也就是說樣本數是未知的。這里的坑在於,如果我們用print看的話會發現,構建的是張量(Tensor)而不是矩陣,這里對熟悉matlab的同學來說可能是個坑。可以注意下張量的定義方式。

  第14和15行是定義了變量,如果只看tf.zeros([10])的話也是個張量的,只是外邊又加了變量的聲明。所以后邊可以直接乘的,這個也不難理解了。

  第18行的matmul是張量相乘,然后使用了softmax模型,目的是把結果進行概率化。巧妙,只想說這兩個字,這個就是進行歸一化,搞算法這個是比較常用的,學校時候這個詞很火,我們最終想得到的是一個指定的數組,所以用這個模型來匹配我的規則。

  21行是什么,看完就知道是實際的輸出,然后在24行做交叉熵。終於又碰到熵這個老朋友了。交叉熵簡單理解為概率分布的距離,在這里作為一個loss_function。第27行使用了梯度下降來優化這個loss_function,最終是想找到最優時候的一個模型,這里的最優指的是通過這個模型,得到的結果和實際值最接近。

  第30行,創建一個session。

  第33行,初始化變量。

  第37行,可以去看下next_batch的源碼,作用是選取100個樣本來訓練。

  第41行,注意equal函數的作用,第43行來做類型轉換,然后取平均值。(代碼很巧妙,很優雅,很爽)

  最終第44行輸出模型的准確率。

  好了,這大概就是我的一點點總結了,算是入了個門,接下來我會更多的舉一反三,深入掌握其精髓,我會努力走得更遠。

  作為一個小白,我要繼續努力向大牛學習,吃飯去咯,下周再戰。

 

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM