詳解 MNIST 數據集


轉自:https://blog.csdn.net/simple_the_best/article/details/75267863

MNIST 數據集已經是一個被”嚼爛”了的數據集, 很多教程都會對它”下手”, 幾乎成為一個 “典范”. 不過有些人可能對它還不是很了解, 下面來介紹一下.

MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓后 47 MB, 包含 60,000 個樣本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓后 60 KB, 包含 60,000 個標簽)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓后 7.8 MB, 包含 10,000 個樣本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓后 10 KB, 包含 10,000 個標簽)

MNIST 數據集來自美國國家標准與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據.

不妨新建一個文件夾 – mnist, 將數據集下載到 mnist 以后, 解壓即可:

dataset

圖片是以字節的形式進行存儲, 我們需要把它們讀取到 NumPy array 中, 以便訓練和測試算法.

import os import struct import numpy as np def load_mnist(path, kind='train'): """Load MNIST data from `path`""" labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind) images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) return images, labels

load_mnist 函數返回兩個數組, 第一個是一個 n x m 維的 NumPy array(images), 這里的 n 是樣本數(行數), m 是特征數(列數). 訓練數據集包含 60,000 個樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示. 在這里, 我們將 28 x 28 的像素展開為一個一維的行向量, 這些行向量就是圖片數組里的行(每行 784 個值, 或者說每行就是代表了一張圖片). load_mnist 函數返回的第二個數組(labels) 包含了相應的目標變量, 也就是手寫數字的類標簽(整數 0-9).

第一次見的話, 可能會覺得我們讀取圖片的方式有點奇怪:

magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8)

為了理解這兩行代碼, 我們先來看一下 MNIST 網站上對數據集的介紹:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte): [offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 60000 number of items 0008 unsigned byte ?? label 0009 unsigned byte ?? label ........ xxxx unsigned byte ?? label The labels values are 0 to 9.

通過使用上面兩行代碼, 我們首先讀入 magic number, 它是一個文件協議的描述, 也是在我們調用 fromfile 方法將字節讀入 NumPy array 之前在文件緩沖中的 item 數(n). 作為參數值傳入 struct.unpack 的 >II 有兩個部分:

  • >: 這是指大端(用來定義字節是如何存儲的); 如果你還不知道什么是大端和小端, Endianness 是一個非常好的解釋. (關於大小端, 更多內容可見<<深入理解計算機系統 – 2.1 節信息存儲>>)
  • I: 這是指一個無符號整數.

通過執行下面的代碼, 我們將會從剛剛解壓 MNIST 數據集后的 mnist 目錄下加載 60,000 個訓練樣本和 10,000 個測試樣本.

為了了解 MNIST 中的圖片看起來到底是個啥, 讓我們來對它們進行可視化處理. 從 feature matrix 中將 784-像素值 的向量 reshape 為之前的 28*28 的形狀, 然后通過 matplotlib 的 imshow 函數進行繪制:

import matplotlib.pyplot as plt fig, ax = plt.subplots( nrows=2, ncols=5, sharex=True, sharey=True, ) ax = ax.flatten() for i in range(10): img = X_train[y_train == i][0].reshape(28, 28) ax[i].imshow(img, cmap='Greys', interpolation='nearest') ax[0].set_xticks([]) ax[0].set_yticks([]) plt.tight_layout() plt.show()

我們現在應該可以看到一個 2*5 的圖片, 里面分別是 0-9 單個數字的圖片.

0-9

此外, 我們還可以繪制某一數字的多個樣本圖片, 來看一下這些手寫樣本到底有多不同:

fig, ax = plt.subplots(
    nrows=5, ncols=5, sharex=True, sharey=True, ) ax = ax.flatten() for i in range(25): img = X_train[y_train == 7][i].reshape(28, 28) ax[i].imshow(img, cmap='Greys', interpolation='nearest') ax[0].set_xticks([]) ax[0].set_yticks([]) plt.tight_layout() plt.show()

執行上面的代碼后, 我們應該看到數字 7 的 25 個不同形態:

7

另外, 我們也可以選擇將 MNIST 圖片數據和標簽保存為 CSV 文件, 這樣就可以在不支持特殊的字節格式的程序中打開數據集. 但是, 有一點要說明, CSV 的文件格式將會占用更多的磁盤空間, 如下所示:

  • train_img.csv: 109.5 MB
  • train_labels.csv: 120 KB
  • test_img.csv: 18.3 MB
  • test_labels: 20 KB

如果我們打算保存這些 CSV 文件, 在將 MNIST 數據集加載入 NumPy array 以后, 我們應該執行下列代碼:

np.savetxt('train_img.csv', X_train, fmt='%i', delimiter=',') np.savetxt('train_labels.csv', y_train, fmt='%i', delimiter=',') np.savetxt('test_img.csv', X_test, fmt='%i', delimiter=',') np.savetxt('test_labels.csv', y_test, fmt='%i', delimiter=',')

一旦將數據集保存為 CSV 文件, 我們也可以用 NumPy 的 genfromtxt 函數重新將它們加載入程序中:

X_train = np.genfromtxt('train_img.csv', dtype=int, delimiter=',') y_train = np.genfromtxt('train_labels.csv', dtype=int, delimiter=',') X_test = np.genfromtxt('test_img.csv', dtype=int, delimiter=',') y_test = np.genfromtxt('test_labels.csv', dtype=int, delimiter=',')

不過, 從 CSV 文件中加載 MNIST 數據將會顯著發給更長的時間, 因此如果可能的話, 還是建議你維持數據集原有的字節格式.

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM