一、前言
1、前廣泛使用的圖像分類數據集之一是 MNIST 數據集,雖然它是很不錯的基准數據集,但按今天的標准,即使是簡單的模型也能達到95%以上的分類准確率,因此不適合區分強模型和弱模型。
2、為了提高難度,我們將在接下來的章節中討論在2017年發布的性質相似但相對復雜的Fashion-MNIST數據集
二、讀取數據集
%matplotlib inline import torch import torchvision #pytorch對於計算機視覺模型實現的一些庫 from torch.utils import data from torchvision import transforms #transforms對數據進行操作 from d2l import torch as d2l d2l.use_svg_display()#使用svg來顯示圖片,清晰度會高一點
1、通過框架中的內置函數將Fashion-MNIST數據集下載並讀取到內存中
# 通過ToTensor實例將圖像數據從PIL類型變換成32位浮點數格式 # 並除以255使得所有像素的數值均在0到1之間(歸一化) trans = transforms.ToTensor() #下載到上一級目錄的data下面 #train=True,表示下載的是訓練數據集 #transform=trans,表示下載的要轉為pytorch的tensor,而不是一堆圖片 #download=True 意思是默認從網上下載 mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) #測試數據集,驗證模型的好壞 mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
2、Fashion-MNIST由10個類別的圖像組成,每個類別由訓練數據集中6000張圖像和測試數據集中1000張圖像組成。
len(mnist_train), len(mnist_test) #下載成功后輸出發現,訓練數據集有60000張圖片,測試數據集有10000張圖片 #輸出結果 (60000, 10000)
3、每個輸入圖像的高度和寬度均為28像素,數據集由灰度圖像組成,其通道為1。
# fashionmnist數據集同時包含圖形和標簽,第二個零表示取圖片,如果是【1】則是標簽 #print(mnist_train[0][0]) mnist_train[0][0].shape #第一個零:就是第零個樣例 #第二個零:表示取圖片,如果是1就表示標簽 #輸出1表示是黑白圖片 #28,28分表表示長和寬
#輸出結果
torch.Size([1, 28, 28])
4、Fashion-MNIST中包含的10個類別分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。
def get_fashion_mnist_labels(labels): #@save """返回Fashion-MNIST數據集的文本標簽。""" text_labels = [ 't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
5、可視化樣本函數(不求理解)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save """Plot a list of images.""" figsize = (num_cols * scale, num_rows * scale) _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): # 圖片張量 ax.imshow(img.numpy()) else: # PIL圖片 ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes
6、
#構造了pytorch數據集之后放在DateLoader中,指定一個batch_size #next()拿到第一個小批量。batch_size:批量大小 X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) #2行9列 show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));#y是一個數值的標號 #輸出結果 圖片略
7、讀取小批量。為了使我們在讀取訓練集和測試集時更容易,我們使用內置的數據迭代器,而不是從零開始創建一個。
batch_size = 256 def get_dataloader_workers(): #@save """使用4個進程來讀取數據。""" return 4 #shuffle=True表明需要隨機,打亂順序 #訓練集需要打亂順序,測試集不用打亂順序 #num_workers表明要給多少進程 train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
8、查看讀取訓練數據所需時間
#Timer()函數用來測試速度-讀一次數據所用的時間 timer = d2l.Timer() #構造出train_iter之后,使用for循環一個個來訪問batch for X, y in train_iter: continue f'{timer.stop():.2f} sec'
#輸出結果
'8.76 sec'
9、整合所有組件
定義 load_data_fashion_mnist
函數],用於獲取和讀取Fashion-MNIST數據集。它返回訓練集和驗證集的數據迭代器。此外,它還接受一個可選參數,用來將圖像大小調整為另一種形狀。
#resize def load_data_fashion_mnist(batch_size, resize=None): #@save """下載Fashion-MNIST數據集,然后將其加載到內存中。""" trans = [transforms.ToTensor()] # resize:調整圖像大小 if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) # 下載數據集 mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True) # 返回小批量 return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
10、通過指定resize參數來測試load_data_fashion_mnist
函數的圖像大小調整功能
#32表示為batch_size的大小,resize表示重新調整的圖片大小
train_iter, test_iter = load_data_fashion_mnist(32, resize=64) for X, y in train_iter: print(X.shape, X.dtype, y.shape, y.dtype) brea #輸出結果 torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64