1 import sys, os 2 sys.path.append(os.pardir) 3 import numpy as np 4 from dataset.mnist import load_mnist 5 from PIL import Image 6 7 def img_show(img): 8 pil_img = Image.fromarray(np.uint8(img)) 9 pil_img.show() 10 11 (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, 12 normalize=False) 13 img = x_train[0] 14 label = t_train[0] 15 print(label) # 5 16 17 print(img.shape) # (784,) 18 img = img.reshape(28, 28) # 把圖像的形狀變成原來的尺寸 19 print(img.shape) # (28, 28) 20 21 img_show(img)
顯示mnist圖像,執行上述代碼后,訓練圖像的第一張會顯示出來。sys.path.append(os.pardir)導入父目錄,第一次調用load_mnist函數時,因為要下載MNIST數據集,所以需要聯網進行。第2次及以后的調用只需要讀入保存在本地的文件(pickle文件)即可,因此處理所需時間都非常短。
load_mnist函數以“(訓練圖像,訓練標簽),(測試圖像,測試標簽)”的形式返回讀入的MNST數據。此外,還可以像
load_mnist(normalize=True, flatten=True, one_hot_label=False)
這樣,設置3個參數。第1個參數normalize設置是否將輸入圖像正規化為0.0~1.0的值。如果將該參數設置為False,則輸入圖像的像素會保持原來的0~255。第2個參數flatten設置是否展開輸入圖像(變成一維數據)。如果將該參數設置為False,則輸入圖像為1 × 28 × 28的三維數組;若設置為True,則輸入圖像會保存為由784個元素構成的一位數組。第3個參數one_hot_label設置是否將標簽保存為one-hot表示(one-hot representation)。one-hot表示是僅正確解標簽為1,其余皆為0的數組,就像[0,0,1,0,0,0,0,0,0,0]這樣。當one_hot_label為False時,知識想7,2這樣簡單保存正確解標簽;當one_hot_label為True時,標簽則保存為one-hot表示。
Python 有 pickle 這個便利的功能。這個功能可以將程序運行中的對象保存為文件。如果加載保存過的 pickle 文件,可以立刻復原之前程序運行中的對象。用於讀入 MNIST 數據集的
load_mnist()
函數內部也使用了 pickle 功能(在第 2 次及以后讀入時)。利用 pickle 功能,可以高效地完成 MNIST 數據的准備工作。
這里需要注意的是,flatten=True
時讀入的圖像是以一列(一維)NumPy 數組的形式保存的。因此,顯示圖像時,需要把它變為原來的 28 像素 × 28 像素的形狀。可以通過 reshape()
方法的參數指定期望的形狀,更改 NumPy 數組的形狀。此外,還需要把保存為 NumPy 數組的圖像數據轉換為 PIL 用的數據對象,這個轉換處理由 Image.fromarray()
來完成。
img = x_train[0] #x_train的形狀是(6000,784),即6000行728列的矩陣,所以x_train[0]表示第一列的784個數據
label = t_train[0] #t_train的形狀是(6000,),即一行或者一列數據6000個,所以t_train[0]是第一個數據,這里它的值是5