tensorflow學習筆記三:實例數據下載與讀取


一、mnist數據

深度學習的入門實例,一般就是mnist手寫數字分類識別,因此我們應該先下載這個數據集。

tensorflow提供一個input_data.py文件,專門用於下載mnist數據,我們直接調用就可以了,代碼如下:

import tensorflow.examples.tutorials.mnist.input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

執行完成后,會在當前目錄下新建一個文件夾MNIST_data, 下載的數據將放入這個文件夾內。下載的四個文件為:

input_data文件會調用一個maybe_download函數,確保數據下載成功。這個函數還會判斷數據是否已經下載,如果已經下載好了,就不再重復下載。

下載下來的數據集被分三個子集:5.5W行的訓練數據集(mnist.train),5千行的驗證數據集(mnist.validation)和1W行的測試數據集(mnist.test)。因為每張圖片為28x28的黑白圖片,所以每行為784維的向量。

每個子集都由兩部分組成:圖片部分(images)和標簽部分(labels), 我們可以用下面的代碼來查看 :

print mnist.train.images.shape
print mnist.train.labels.shape
print mnist.validation.images.shape
print mnist.validation.labels.shape
print mnist.test.images.shape
print mnist.test.labels.shape

如果想在spyder編輯器中查看具體數值,可以將這些數據提取為變量來查看,如:

val_data=mnist.validation.images
val_label=mnist.validation.labels

二、CSV數據 

除了mnist手寫字體圖片數據,tf還提供了幾個csv的數據供大家練習,存放路徑為:

/home/xxx/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/data/text_train.csv

如果要將這些數據讀出來,可用代碼:

import tensorflow.contrib.learn.python.learn.datasets.base as base
iris_data,iris_label=base.load_iris()
house_data,house_label=base.load_boston()

前者為iris鳶尾花卉數據集,后者為波士頓房價數據。

三、cifar10數據

tf提供了cifar10數據的下載和讀取的函數,我們直接調用就可以了。執行下列代碼:

import tensorflow.models.image.cifar10.cifar10 as cifar10
cifar10.maybe_download_and_extract()
images, labels = cifar10.distorted_inputs()
print images
print labels

就可以將cifar10下載並讀取出來。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM