Python機器學習(七十三)Keras 加載MNIST數據集


MNIST是一個經典的深度學習和計算機視覺的數據集,里面包含了0-9的手寫數字圖片,開發人員可使用此數據集來訓練和測試神經網絡,訓練后的神經網絡可以識別手寫數字。

Keras庫已經包含了這個數據集,可以從Keras庫中加載:

from keras.datasets import mnist

# 將預打亂的MNIST數據加載到訓練和測試集中
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
# 11493376/11490434 [==============================] - 483s 42us/step

可以查看數據集的形狀:

print (X_train.shape)
# (60000, 28, 28)

可以看到,訓練集中有60000個樣本,每個圖像都是28像素x28像素。要查看手寫數字圖像,可以使用matplotlib繪制,下面繪制MNIST數據集中的第一個圖像:

from matplotlib import pyplot as plt
plt.imshow(X_train[0])
plt.show()

這是圖像輸出:

圖

一般來說,當開發深度學習應用時,在進行任何算法工作之前,可視化地繪制數據是很有用的。這是一個快速的完整性檢查,可以防止低級錯誤(比如搞錯數據維度)。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM