圖像分類數據集


一、前言

1、前廣泛使用的圖像分類數據集之一是 MNIST 數據集,雖然它是很不錯的基准數據集,但按今天的標准,即使是簡單的模型也能達到95%以上的分類准確率,因此不適合區分強模型和弱模型。

2、為了提高難度,我們將在接下來的章節中討論在2017年發布的性質相似但相對復雜的Fashion-MNIST數據集

 

二、讀取數據集

%matplotlib inline
import torch
import torchvision   #pytorch對於計算機視覺模型實現的一些庫
from torch.utils import data
from torchvision import transforms #transforms對數據進行操作
from d2l import torch as d2l

d2l.use_svg_display()#使用svg來顯示圖片,清晰度會高一點

1、通過框架中的內置函數將Fashion-MNIST數據集下載並讀取到內存中

# 通過ToTensor實例將圖像數據從PIL類型變換成32位浮點數格式
# 並除以255使得所有像素的數值均在0到1之間(歸一化)

trans = transforms.ToTensor()

#下載到上一級目錄的data下面
#train=True,表示下載的是訓練數據集
#transform=trans,表示下載的要轉為pytorch的tensor,而不是一堆圖片
#download=True 意思是默認從網上下載
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
                                                transform=trans,
                                                download=True)

#測試數據集,驗證模型的好壞
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
                                               transform=trans, download=True)

2、Fashion-MNIST由10個類別的圖像組成,每個類別由訓練數據集中6000張圖像和測試數據集中1000張圖像組成。

len(mnist_train), len(mnist_test)

#下載成功后輸出發現,訓練數據集有60000張圖片,測試數據集有10000張圖片

#輸出結果

(60000, 10000)

3、每個輸入圖像的高度和寬度均為28像素,數據集由灰度圖像組成,其通道為1。

# fashionmnist數據集同時包含圖形和標簽,第二個零表示取圖片,如果是【1】則是標簽

#print(mnist_train[0][0])
mnist_train[0][0].shape
#第一個零:就是第零個樣例
#第二個零:表示取圖片,如果是1就表示標簽

#輸出1表示是黑白圖片
#28,28分表表示長和寬

#輸出結果
torch.Size([1, 28, 28])

4、Fashion-MNIST中包含的10個類別分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST數據集的文本標簽。"""
    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
        'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

5、可視化樣本函數(不求理解)

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 圖片張量
            ax.imshow(img.numpy())
        else:
            # PIL圖片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

6、

#構造了pytorch數據集之后放在DateLoader中,指定一個batch_size
#next()拿到第一個小批量。batch_size:批量大小
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))

#2行9列
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));#y是一個數值的標號

#輸出結果

圖片略

7、讀取小批量。為了使我們在讀取訓練集和測試集時更容易,我們使用內置的數據迭代器,而不是從零開始創建一個。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4個進程來讀取數據。"""
    return 4

#shuffle=True表明需要隨機,打亂順序
#訓練集需要打亂順序,測試集不用打亂順序
#num_workers表明要給多少進程
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

8、查看讀取訓練數據所需時間

#Timer()函數用來測試速度-讀一次數據所用的時間
timer = d2l.Timer()

#構造出train_iter之后,使用for循環一個個來訪問batch
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

#輸出結果
'8.76 sec'

9、整合所有組件

定義 load_data_fashion_mnist 函數],用於獲取和讀取Fashion-MNIST數據集。它返回訓練集和驗證集的數據迭代器。此外,它還接受一個可選參數,用來將圖像大小調整為另一種形狀。

#resize
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下載Fashion-MNIST數據集,然后將其加載到內存中。"""
    trans = [transforms.ToTensor()]
    
    # resize:調整圖像大小
    if resize:
        trans.insert(0, transforms.Resize(resize))
        
    trans = transforms.Compose(trans)
    # 下載數據集
    mnist_train = torchvision.datasets.FashionMNIST(root="../data",
                                                    train=True,
                                                    transform=trans,
                                                    download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data",
                                                   train=False,
                                                   transform=trans,
                                                   download=True)
    # 返回小批量
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

10、通過指定resize參數來測試load_data_fashion_mnist函數的圖像大小調整功能

#32表示為batch_size的大小,resize表示重新調整的圖片大小
train_iter, test_iter = load_data_fashion_mnist(32, resize=64) for X, y in train_iter: print(X.shape, X.dtype, y.shape, y.dtype) brea #輸出結果 torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM