Demo俠可能是我等小白進階的必經之路了,如今在AI領域,我也是個研究Demo的小白。用了兩三天裝好環境,跑通Demo,自學Python語法,進而研究這個Demo。當然過程中查了很多資料,充分發揮了小白的主觀能動性,總算有一些收獲需要總結下。
不多說,算法在代碼中,一切也都在代碼中。
1 import os 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3
4 #獲得數據集
5 from tensorflow.examples.tutorials.mnist import input_data 6 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 7
8 import tensorflow as tf 9
10 #輸入圖像數據占位符
11 x = tf.placeholder(tf.float32, [None, 784]) 12
13 #權值和偏差
14 W = tf.Variable(tf.zeros([784, 10])) 15 b = tf.Variable(tf.zeros([10])) 16
17 #使用softmax模型
18 y = tf.nn.softmax(tf.matmul(x, W) + b) 19
20 #代價函數占位符
21 y_ = tf.placeholder(tf.float32, [None, 10]) 22
23 #交叉熵評估代價
24 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 25
26 #使用梯度下降算法優化:學習速率為0.5
27 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 28
29 #Session(交互方式)
30 sess = tf.InteractiveSession() 31
32 #初始化變量
33 tf.global_variables_initializer().run() 34
35 #訓練模型,訓練1000次
36 for _ in range(1000): 37 batch_xs, batch_ys = mnist.train.next_batch(100) 38 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 39
40 #計算正確率
41 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 42
43 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 44 print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
看完這個Demo,頓時感覺Python真是一門好語言,Tensorflow是一個好框架,就跟之前掌握Matlab以后,用Matlab做仿真的感覺一樣。
為什么看這幾行代碼看了兩三天,因為看懂很容易,但了解代碼背后的意義更重要,如果把一個Demo看透了,那么后邊舉一反三就會很容易了,我向來就是這樣學習的,本小白當年也是個學霸?!
來一起看下這里邊有什么玄機和坑吧,記錄一下,人老了記性不好(^-^)。
看到1,2行代碼,不要懵,這個作用是設置日志級別,os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只顯示 warning 和 Error,等於1是顯示所有信息。不加這兩行會有個提示(Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2,具體可以看這里)
第5行是一個引用聲明,從tensorflow.examples.tutorials.mnist 引用一個名為 input_data 的函數,可以看一下input_data是什么樣子的:
1 from __future__ import absolute_import 2 from __future__ import division 3 from __future__ import print_function 4
5 import gzip 6 import os 7 import tempfile 8
9 import numpy 10 from six.moves import urllib 11 from six.moves import xrange # pylint: disable=redefined-builtin
12 import tensorflow as tf 13 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
原來input_data里邊也是引用聲明,真正想用到的實際是tensorflow.contrib.learn.python.learn.datasets.mnist里的read_data_sets,看一下代碼:
1 def read_data_sets(train_dir, 2 fake_data=False, 3 one_hot=False, 4 dtype=dtypes.float32, 5 reshape=True, 6 validation_size=5000, 7 seed=None, 8 source_url=DEFAULT_SOURCE_URL): 9 if fake_data: 10 ... 11
12 if not source_url: # empty string check
13 ... 14
15 local_file = base.maybe_download(TRAIN_IMAGES, train_dir, 16 source_url + TRAIN_IMAGES) 17 with gfile.Open(local_file, 'rb') as f: 18 train_images = extract_images(f) 19
20 ... 21
22 if not 0 <= validation_size <= len(train_images): 23 raise ValueError('Validation size should be between 0 and {}. Received: {}.'
24 .format(len(train_images), validation_size)) 25
26 validation_images = train_images[:validation_size] 27 validation_labels = train_labels[:validation_size] 28 train_images = train_images[validation_size:] 29 train_labels = train_labels[validation_size:] 30
31 options = dict(dtype=dtype, reshape=reshape, seed=seed) 32
33 train = DataSet(train_images, train_labels, **options) 34 validation = DataSet(validation_images, validation_labels, **options) 35 test = DataSet(test_images, test_labels, **options) 36
37 return base.Datasets(train=train, validation=validation, test=test)
mnist最終得到的是base.Datasets,完成了數據讀取。這里邊的細節還需要完了再仔細研究下。
順便記錄下自編的函數的定義方法:
1 def Mycollect(My , thing): 2 3 try: 4 count = My[thing] 5 except KeyError: 6 count = 0 7 8 return count 9 10 from TestFunction import Mycollect 11 My = {'a':10, 'b':15, 'c':5} 12 thing = 'a' 13 print(Mycollect(My , thing));
第11行的placeholder,需要注意下,是用了占位符,也就是先安排位置,而不先提供具體數據,也就是說都是模型(管道)的構建過程(這里用管道來類比,我覺得比較恰當)。注意下placeholder的語法就可以,指定了type和shape,這里的None表示有多少幅圖片是未知的,也就是說樣本數是未知的。這里的坑在於,如果我們用print看的話會發現,構建的是張量(Tensor)而不是矩陣,這里對熟悉matlab的同學來說可能是個坑。可以注意下張量的定義方式。
第14和15行是定義了變量,如果只看tf.zeros([10])的話也是個張量的,只是外邊又加了變量的聲明。所以后邊可以直接乘的,這個也不難理解了。
第18行的matmul是張量相乘,然后使用了softmax模型,目的是把結果進行概率化。巧妙,只想說這兩個字,這個就是進行歸一化,搞算法這個是比較常用的,學校時候這個詞很火,我們最終想得到的是一個指定的數組,所以用這個模型來匹配我的規則。
21行是什么,看完就知道是實際的輸出,然后在24行做交叉熵。終於又碰到熵這個老朋友了。交叉熵簡單理解為概率分布的距離,在這里作為一個loss_function。第27行使用了梯度下降來優化這個loss_function,最終是想找到最優時候的一個模型,這里的最優指的是通過這個模型,得到的結果和實際值最接近。
第30行,創建一個session。
第33行,初始化變量。
第37行,可以去看下next_batch的源碼,作用是選取100個樣本來訓練。
第41行,注意equal函數的作用,第43行來做類型轉換,然后取平均值。(代碼很巧妙,很優雅,很爽)
最終第44行輸出模型的准確率。
好了,這大概就是我的一點點總結了,算是入了個門,接下來我會更多的舉一反三,深入掌握其精髓,我會努力走得更遠。
作為一個小白,我要繼續努力向大牛學習,吃飯去咯,下周再戰。
