3. Dataset、transform和Dataloader的聯立使用


在前兩篇我博客1.法寶函數、編譯器的初級使用和使用Dataset2. tensorboard和 transform的使用中,我分別介紹了 Dataset 和 transform 的簡單使用,並推薦使用了 pytorch 中常用的日志工具 tensorboard,在本篇博客中,我將繼續介紹 Dataset 和 Dataloader的使用,主要介紹數據的加載方式。

1. Datasets + transform

torch.utils.data.Dataset 的回顧

  老辦法,我們還是先查看官方文檔
pytorch.org 上的介紹:

pycharm 上的介紹:

這兩種介紹大同小異,就不再二次翻譯了。
這個是我們之前學的 Dataset,直接進行繼承該類
手寫一個代碼看看:

from PIL import Image
from torch.utils.data import Dataset
import os


class MyDataset(Dataset):
    def __init__(self, root_dir, label):
        """傳遞文件所在位置的參數"""
        self.root_dir = root_dir
        self.label = label
        self.label_path = os.path.join(self.root_dir, self.label)
        self.img_name_list = os.listdir(self.label_path)

    def __getitem__(self, idx):
        """根據文件所在的位置,讀取文件后返回"""
        img_path = os.path.join(self.label_path, self.img_name_list[idx])
        return Image.open(img_path), self.label

    def __len__(self):
        return len(self.img_name_list)


if "__main__" == __name__:
    my_data = MyDataset("./data/train", "ants")
    img_data, label_data = my_data[0]
    img_data.show()

transform 的回顧

transform使我們剛剛學過的,就是調用 transform 庫中的幫助文檔,如果不太了解的話可以查看一下我的上一篇博客2. tensorboard和 transform的使用

Datasets

  請注意,這里是 Datasets 而不是 Dataset,這是一個復數!我們首先查看一下他的幫助文檔

下面,我們以 torchvision.datasets.CIFAR10 為例,進行展示使用方法

CIFAR10 幫助文檔

transform 和 target_transform 處理的對象不同,一個是 Image,一個是 image 對應的類別 target

CIFAR10的使用代碼

from torchvision import datasets
from torchvision import transforms
from PIL import Image
import cv2
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset


root_dir = "data_cifar10"
data_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
data_test = datasets.CIFAR10(root=root_dir, train=False, transform=transforms.ToTensor(), download=True)

targets = []
img, target = data_test[0]
trans_to_PIL = transforms.ToPILImage()
trans_to_PIL(img).show()

for i in range(10):
    targets.append(data_test[i][1])
print(targets)

debug 過程中, 查看 data_train 的屬性

建議

  1. download=True
      當我們使用 dataset的 時候,最好設置 download=True,通過設置參數為True,使得我們給定的 root 文件夾不存在數值的時候,自動下載文件,倘若存在的時候,並不會下載文件;也就是說,download=True,代碼不容易出現一下奇奇怪怪的錯誤

  2. 資源下載
      當我們下載文件比較慢的時候,可以通過自己手動下載,或者是借助一些下載工具完成下載,然后將其復制到我們的目標目錄下即可。這主要是考慮到使用一些下載器其實能夠起到一下加速下載的效果。
      文件的下載鏈接也比較容易找到,如下圖所示:

或者是進入我們的pycharm 幫助文檔進行查找:

2. Dataloader

  PyTorch中數據讀取的一個重要接口是torch.utils.data.DataLoader,主要用來將自定義的數據讀取接口的輸出或者PyTorch已有的數據讀取接口的輸入按照batch size封裝成Tensor,后續只需要再包裝成Variable即可作為模型的輸入。
  下面,我們打開官方文檔來一探究竟。

Dataloader 官方文檔

Dataloader 的使用

import cv2
from PIL import Image
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter


root_dir = "./data_cifar10"
dataset_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
dataset_test = datasets.CIFAR10(root=root_dir, train=False, transform=transforms.ToTensor(), download=True)

dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True, drop_last=True)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=64, shuffle=True, drop_last=True)

log_dir = "logs"
writer = SummaryWriter(log_dir=log_dir)
for epoch in range(2):
    step = 0
    for data_imgs, data_targets in dataloader_test:
        writer.add_images(f"epoch{epoch}", data_imgs, step) #
    step += 1

writer.close()

代碼中主要容易的是這句話

for data_imgs, data_targets in dataloader_test:
    writer.add_images(f"epoch{epoch}", data_imgs, step) #
    step += 1

第一個, 我們接受的是兩個參數,一個是 data_imgs, 一個是 data_targets
writer.addImages,注意是 Images,因為添加的並不是一張圖片

附錄

幾個經常導入的 包

from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from PIL import Image
import os
Author: Luckylight(xyg) Date: 2021/11/10


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM