Fashion MNIST的下載與導入


在動手寫深度學習的TensorFlow實現版本中,需要用到數據集Fashion MNIST,如果直接用TensorFlow導入數據集:

from tensorflow.keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

就會報錯,下載數據集時會顯示服務器連接超時,可能因為服務器在國內被牆了。

下面是如何手動下載數據集並導入的步驟:

1.下載數據集

去GitHub上該數據集的主頁下載:https://github.com/zalandoresearch/fashion-mnist

 

 下載完成后解壓放在./data/fashion/文件夾下

 

 接下導入數據集:

 

import mnist_reader

x_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')

注意這里面的mnist_reader是GitHub上該項目里面的一個文件,不要以為是某個庫

 

 可以直接clone整個項目,再把這個文件放在和上文data相同的文件夾下

 

 

 不想下載這個項目呢,這里給出這個文件的具體代碼,在導入數據集時把這個函數加入到你的代碼中也可以:

def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

最后可以測試一下是否導入成功:

 

 最后如果你還是導入不成功,或者GitHub上數據集你就是下載不下來,可以私信我。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM