TensorFlow CNN 測試CIFAR-10數據集



本系列文章由 @yhl_leo 出品,轉載請注明出處。
文章鏈接: http://blog.csdn.net/yhl_leo/article/details/50738311


1 CIFAR-10 數據集

CIFAR-10數據集是機器學習中的一個通用的用於圖像識別的基礎數據集。官網鏈接為:The CIFAR-10 dataset

cifar10

下載使用的版本號是:

version

將其解壓后(代碼中包括自己主動解壓代碼)。內容為:

cifar10 data

cifar10 data2

2 測試代碼

測試代碼發布在GitHub:yhlleo

主要代碼及作用:

文件 作用
cifar10_input.py 讀取本地或者在線下載CIFAR-10的二進制文件格式數據集
cifar10.py 建立CIFAR-10的模型
cifar10_train.py 在CPU或GPU上訓練CIFAR-10的模型
cifar10_multi_gpu_train.py 在多個GPU上訓練CIFAR-10的模型
cifar10_eval.py 評估CIFAR-10模型的預測性能

該部分的代碼,介紹了怎樣使用TensorFlow在CPU和GPU上訓練和評估卷積神經網絡(convolutional neural network, CNN)。

3 相關網頁及教程

更加具體地介紹說明。請瀏覽網頁:Convolutional Neural Networks

中文站點極客學院也有該部分的漢譯版:卷積神經網絡

代碼源自tensorflow官網:tensorflow/models/image/cifar10

4 代碼改動說明

GitHub發布代碼相對源代碼(本人的Tensorflow版本號還是0.5),主要進行了下面修正:

  • cifar10.py
# indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])

# or
indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])

此處,源代碼編譯時會出現下面錯誤:

  ...
  File ".../cifar10.py", line 271, in loss
    indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
TypeError: range() takes at least 2 arguments (1 given)
  • cifar10_input_test.py
#self.a 
posted @ 2017-07-27 15:26  wzzkaifa  閱讀( 1187)  評論( 0編輯  收藏


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM