在動手寫深度學習的TensorFlow實現版本中,需要用到數據集Fashion MNIST,如果直接用TensorFlow導入數據集:
from tensorflow.keras.datasets import fashion_mnist (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
就會報錯,下載數據集時會顯示服務器連接超時,可能因為服務器在國內被牆了。
下面是如何手動下載數據集並導入的步驟:
1.下載數據集
去GitHub上該數據集的主頁下載:https://github.com/zalandoresearch/fashion-mnist
下載完成后解壓放在./data/fashion/文件夾下
接下導入數據集:
import mnist_reader x_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train') x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')
注意這里面的mnist_reader是GitHub上該項目里面的一個文件,不要以為是某個庫
可以直接clone整個項目,再把這個文件放在和上文data相同的文件夾下
不想下載這個項目呢,這里給出這個文件的具體代碼,在導入數據集時把這個函數加入到你的代碼中也可以:
def load_mnist(path, kind='train'): import os import gzip import numpy as np """Load MNIST data from `path`""" labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind) images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind) with gzip.open(labels_path, 'rb') as lbpath: labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) with gzip.open(images_path, 'rb') as imgpath: images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784) return images, labels
最后可以測試一下是否導入成功:
最后如果你還是導入不成功,或者GitHub上數據集你就是下載不下來,可以私信我。