TensorFlow 入門之手寫識別(MNIST) 數據處理 一
准備數據
MNIST是在機器學習領域中的一個經典問題。該問題解決的是把28x28像素的灰度手寫數字圖片識別為相應的數字,其中數字的范圍從0到9.

- from IPython.display import Image
- import base64
- Image(data=base64.decodestring(url),embed=True)
同時我們可以通過TensorFlow提供的例子來下載有Yann LeCun提供的MNIST提供的如上的手寫數據集。
| 文件 | 內容 |
|---|---|
| train-images-idx3-ubyte.gz | 訓練集圖片 - 55000 張 訓練圖片, 5000 張 驗證圖片 |
| train-labels-idx1-ubyte.gz | 訓練集圖片對應的數字標簽 |
| t10k-images-idx3-ubyte.gz | 測試集圖片 - 10000 張 圖片 |
| t10k-labels-idx1-ubyte.gz | 測試集圖片對應的數字標簽 |
- import os
- import urllib
-
- SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
- # WORK_DIRECTORY = "/tmp/mnist-data"
- WORK_DIRECTORY = '/home/fly/TensorFlow/mnist'
-
- def maybe_download(filename):
- """A helper to download the data files if not present."""
- if not os.path.exists(WORK_DIRECTORY):
- os.mkdir(WORK_DIRECTORY)
- filepath = os.path.join(WORK_DIRECTORY, filename)
- if not os.path.exists(filepath):
- filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
- statinfo = os.stat(filepath)
- print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.'
- else:
- print 'Already downloaded', filename
- return filepath
-
- train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
- train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
- test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
- test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
-
解壓 與 重構
數據被解壓成一個二維的Tensor:[image, index, pixel, index],pixel 列是像素的點。0表示的是背景色(白色),255表示的是前景色(黑色)。
上面下載的數據是壓縮的格式,我們需要解壓它。而且每一幅圖是值為[0...255],我們要將它們縮放到[-0.5, 0.5]之間。
我們可以來看一下圖片解壓后的文件格式:
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
對應的代碼是:
- import gzip, binascii, struct, numpy
- import matplotlibpyplot as plt
-
- with gzip.open(test_data_filename) as f:
- # 打印出解壓后的圖片格式的頭格式
- for field in ['magic number', 'image count', 'rows', 'columns']:
- # struct.unpack reads the binary data provided by f.read.
- # The format string '>i' decodes a big-endian integer, which
- # is the encoding of the data.
- print field, struct.unpack('>i', f.read(4))[0]
-
- buf = f.read(28*28)
- image = numpy.frombuffer(buf, dtype=numpy.uint8)
-
- # 打印出前十個image的數據
- print 'First 10 pixels: ', iamge[:10]
-
- # ==>
- # magic number 2051
- # image count 10000
- # rows 28
- # columns 28
- # First 10 pixels: [0 0 0 0 0 0 0 0 0 0]
-
當然我們也可以將解壓后的圖給繪制出來
- # 我們將繪制圖以及關於這個圖的直方圖
- _, (ax1, ax2) = plt.subplots(1,2)
- ax1.imshow(image.reshape(28,28), cmap=plt.cm.Greys)
-
- ax2.hist(image, bins=20, range=[0,255])
-
我們也可以將reScale后的數據繪制出來看看
- # Let's convert the uint8 image to 32 bit floats and rescale
- # the values to be centered around 0, between [-0.5, 0.5].
- #
- # We again plot the image and histogram to check that we
- # haven't mangled the data.
- scaled = image.astype(numpy.float32)
- scaled = (scaled - (255 / 2.0)) / 255
- _, (ax1, ax2) = plt.subplots(1, 2)
- ax1.imshow(scaled.reshape(28, 28), cmap=plt.cm.Greys);
- ax2.hist(scaled, bins=20, range=[-0.5, 0.5]);
-
具體的數據使用可以看TensorFlow提供的測試代碼[在IPython中是第三個實例3_mnist_from_scratch.ipynb]
手寫識別入門
MNIST手寫數據集
MNIST數據集的官網是Yann LeCun's website。 在這里,我們提供了一份python源代碼用於自動下載和安裝這個數據集。你可以下載這份代碼,然后用下面的代碼導入到你的項目里面,也可以直接復制粘貼到你的代碼文件里面。(當然你也可以使用前面提到的代碼來下載手寫的數字數據)
- from tensorflow.examples.tutorials.mnist import input_data
-
- # 通過指定下載地址就可以下載數據
- mnist = input_data.read_data_sets("/home/fly/TensorFlow/mnist", one_hot=True)
-
圖片以及標簽的數據格式處理
下載解壓后,得到的數據分為兩部分,60000行的訓練集(mnist.train)和10000行的測試數據集(mnist.test)。由前面的介紹可以知道,每個MNIST數據有兩部分組成:一個手寫數字的圖片以及一個對應的標簽。比如在訓練集中數據圖片為mnist.train.images以及標簽mnist.train.labels.
因為每一張圖片是28 x 28的像素,所以我們可以使用一個數字數組來表示這張圖:
然后我們再把這個數組展開為長度為28 * 28 = 784 的 一維向量。因此,在MNIST訓練數據集中,mnist.train.images 是一個形狀為 [60000, 784] 的張量,第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。在此張量里的每一個元素,都表示某張圖片里的某個像素的強度值,值介於0和1之間。

相對應的MNIST數據集的標簽是介於0到9的數字,用來描述給定圖片里表示的數字。比如,標簽0將表示成([1,0,0,0,0,0,0,0,0,0,0]).因此, mnist.train.labels 是一個 [60000, 10] 的數字矩陣。

Fly
2016.6

