一、前言
1、前廣泛使用的圖像分類數據集之一是 MNIST 數據集,雖然它是很不錯的基准數據集,但按今天的標准,即使是簡單的模型也能達到95%以上的分類准確率,因此不適合區分強模型和弱模型。
2、為了提高難度,我們將在接下來的章節中討論在2017年發布的性質相似但相對復雜的Fashion-MNIST數據集
二、讀取數據集
%matplotlib inline import torch import torchvision #pytorch對於計算機視覺模型實現的一些庫 from torch.utils import data from torchvision import transforms #transforms對數據進行操作 from d2l import torch as d2l d2l.use_svg_display()#使用svg來顯示圖片,清晰度會高一點
1、通過框架中的內置函數將Fashion-MNIST數據集下載並讀取到內存中
# 通過ToTensor實例將圖像數據從PIL類型變換成32位浮點數格式
# 並除以255使得所有像素的數值均在0到1之間(歸一化)
trans = transforms.ToTensor()
#下載到上一級目錄的data下面
#train=True,表示下載的是訓練數據集
#transform=trans,表示下載的要轉為pytorch的tensor,而不是一堆圖片
#download=True 意思是默認從網上下載
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
transform=trans,
download=True)
#測試數據集,驗證模型的好壞
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
transform=trans, download=True)
2、Fashion-MNIST由10個類別的圖像組成,每個類別由訓練數據集中6000張圖像和測試數據集中1000張圖像組成。
len(mnist_train), len(mnist_test) #下載成功后輸出發現,訓練數據集有60000張圖片,測試數據集有10000張圖片 #輸出結果 (60000, 10000)
3、每個輸入圖像的高度和寬度均為28像素,數據集由灰度圖像組成,其通道為1。
# fashionmnist數據集同時包含圖形和標簽,第二個零表示取圖片,如果是【1】則是標簽 #print(mnist_train[0][0]) mnist_train[0][0].shape #第一個零:就是第零個樣例 #第二個零:表示取圖片,如果是1就表示標簽 #輸出1表示是黑白圖片 #28,28分表表示長和寬
#輸出結果
torch.Size([1, 28, 28])
4、Fashion-MNIST中包含的10個類別分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST數據集的文本標簽。"""
text_labels = [
't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
5、可視化樣本函數(不求理解)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""Plot a list of images."""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 圖片張量
ax.imshow(img.numpy())
else:
# PIL圖片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
6、
#構造了pytorch數據集之后放在DateLoader中,指定一個batch_size #next()拿到第一個小批量。batch_size:批量大小 X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) #2行9列 show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));#y是一個數值的標號 #輸出結果 圖片略
7、讀取小批量。為了使我們在讀取訓練集和測試集時更容易,我們使用內置的數據迭代器,而不是從零開始創建一個。
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4個進程來讀取數據。"""
return 4
#shuffle=True表明需要隨機,打亂順序
#訓練集需要打亂順序,測試集不用打亂順序
#num_workers表明要給多少進程
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
8、查看讀取訓練數據所需時間
#Timer()函數用來測試速度-讀一次數據所用的時間
timer = d2l.Timer()
#構造出train_iter之后,使用for循環一個個來訪問batch
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
#輸出結果
'8.76 sec'
9、整合所有組件
定義 load_data_fashion_mnist 函數],用於獲取和讀取Fashion-MNIST數據集。它返回訓練集和驗證集的數據迭代器。此外,它還接受一個可選參數,用來將圖像大小調整為另一種形狀。
#resize
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下載Fashion-MNIST數據集,然后將其加載到內存中。"""
trans = [transforms.ToTensor()]
# resize:調整圖像大小
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
# 下載數據集
mnist_train = torchvision.datasets.FashionMNIST(root="../data",
train=True,
transform=trans,
download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",
train=False,
transform=trans,
download=True)
# 返回小批量
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
10、通過指定resize參數來測試load_data_fashion_mnist函數的圖像大小調整功能
#32表示為batch_size的大小,resize表示重新調整的圖片大小
train_iter, test_iter = load_data_fashion_mnist(32, resize=64) for X, y in train_iter: print(X.shape, X.dtype, y.shape, y.dtype) brea #輸出結果 torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
