FashionMNIST數據集
Fashion-MNIST是一個10類服飾分類數據集, 我們可以使用它來檢驗不同算法的表現, 這是MNIST數據集不能做到的(原因在這里,想了解的可以看看介紹)。
torchvision的結構
torchvision包包含了很多圖像相關的數據集以及處理方法, 並且有常用的模型結構。
-
torchvision包,它是服務於PyTorch深度學習框架的,主要用來構建計算機視覺模型。torchvision主要由以下幾部分構成:
-
torchvision.datasets: 一些加載數據的函數及常用的數據集接口;
-
torchvision.models: 包含常用的模型結構(含預訓練模型),例如AlexNet、VGG、ResNet等;
-
torchvision.transforms: 常用的圖片變換,例如裁剪、旋轉等;
-
torchvision.utils: 其他的一些有用的方法。
# 導入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
import matplotlib.pyplot as plt
加載數據
設置數據的緩存目錄為 root_dir
隨后獲得訓練集和測試集數據,第一次運行的時候會下載 FashionMNIST 數據集到指定的目錄下
下載速度慢解決方案: Gitee 極速下載 Fashion-MNIST
將Fashion-MNIST/ data / fashion的四個壓縮文件解壓到指定的目錄,不要刪除原來的壓縮包文件,因此數據集總共有八個文件。
# 通過標簽得到描述語句
def get_f_mnist_labels(labels):
"""
:param labels: 圖片對應的標簽(0-9的數字)
:return: 標簽對應的描述
"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
"""
:param images: 讀取的圖片
:param labels: 圖片對應的標簽
:return: None, 輸出圖片,並且在圖片上方對應標簽給出描述
"""
_, figs = plt.subplots(1, len(images), figsize=(12, 2))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)))
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
root_dir = "./torchvision/data/"
f_mnist_train = FashionMNIST(root=root_dir, train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = FashionMNIST(root=root_dir, train=False, download=True, transform=transforms.ToTensor())
print("f_mnist_train length:", len(f_mnist_train), end='\n')
print("f_mnist_test length:", len(f_mnist_test), end='\n')
x, y = [], []
for i in range(10):
x.append(f_mnist_train[i][0])
y.append(f_mnist_train[i][1])
show_fashion_mnist(x, get_f_mnist_labels(y))
f_mnist_train length: 60000
f_mnist_test length: 10000
讀取小批量數據
from torch.utils.data import DataLoader
batch_size = 256
train_iter = DataLoader(f_mnist_train, batch_size, shuffle=True, num_workers = 0)
# 計算加載數據的時間
import time
start = time.time()
for X, y in train_iter:
continue
print("read train data cost %.4f seconds" % (time.time()-start))
read train data cost 4.9213 seconds
注意
本章的介紹思路來自 Apple Store的 “Python AI” app, 作為學習目的使用, 以及在此文章中記錄學習過程(如有侵權,請聯系作者刪除。)