本系列文章由 @yhl_leo 出品,轉載請注明出處。
文章鏈接: http://blog.csdn.net/yhl_leo/article/details/50738311
1 CIFAR-10 數據集
CIFAR-10數據集是機器學習中的一個通用的用於圖像識別的基礎數據集。官網鏈接為:The CIFAR-10 dataset

下載使用的版本號是:

將其解壓后(代碼中包括自己主動解壓代碼)。內容為:


2 測試代碼
測試代碼發布在GitHub:yhlleo
主要代碼及作用:
| 文件 | 作用 |
|---|---|
cifar10_input.py |
讀取本地或者在線下載CIFAR-10的二進制文件格式數據集 |
cifar10.py |
建立CIFAR-10的模型 |
cifar10_train.py |
在CPU或GPU上訓練CIFAR-10的模型 |
cifar10_multi_gpu_train.py |
在多個GPU上訓練CIFAR-10的模型 |
cifar10_eval.py |
評估CIFAR-10模型的預測性能 |
該部分的代碼,介紹了怎樣使用TensorFlow在CPU和GPU上訓練和評估卷積神經網絡(convolutional neural network, CNN)。
3 相關網頁及教程
更加具體地介紹說明。請瀏覽網頁:Convolutional Neural Networks
中文站點極客學院也有該部分的漢譯版:卷積神經網絡
代碼源自tensorflow官網:tensorflow/models/image/cifar10
4 代碼改動說明
GitHub發布代碼相對源代碼(本人的Tensorflow版本號還是0.5),主要進行了下面修正:
cifar10.py
# indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])
# or
indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
此處,源代碼編譯時會出現下面錯誤:
...
File ".../cifar10.py", line 271, in loss
indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
TypeError: range() takes at least 2 arguments (1 given)
cifar10_input_test.py
#self.a
