在前兩篇我博客1.法寶函數、編譯器的初級使用和使用Dataset 和2. tensorboard和 transform的使用中,我分別介紹了 Dataset 和 transform 的簡單使用,並推薦使用了 pytorch 中常用的日志工具 tensorboard,在本篇博客中,我將繼續介紹 Dataset 和 Dataloader的使用,主要介紹數據的加載方式。
1. Datasets + transform
torch.utils.data.Dataset 的回顧
老辦法,我們還是先查看官方文檔
pytorch.org 上的介紹:
pycharm 上的介紹:
這兩種介紹大同小異,就不再二次翻譯了。
這個是我們之前學的 Dataset,直接進行繼承該類
手寫一個代碼看看:
from PIL import Image
from torch.utils.data import Dataset
import os
class MyDataset(Dataset):
def __init__(self, root_dir, label):
"""傳遞文件所在位置的參數"""
self.root_dir = root_dir
self.label = label
self.label_path = os.path.join(self.root_dir, self.label)
self.img_name_list = os.listdir(self.label_path)
def __getitem__(self, idx):
"""根據文件所在的位置,讀取文件后返回"""
img_path = os.path.join(self.label_path, self.img_name_list[idx])
return Image.open(img_path), self.label
def __len__(self):
return len(self.img_name_list)
if "__main__" == __name__:
my_data = MyDataset("./data/train", "ants")
img_data, label_data = my_data[0]
img_data.show()
transform 的回顧
transform
使我們剛剛學過的,就是調用 transform 庫中的幫助文檔,如果不太了解的話可以查看一下我的上一篇博客2. tensorboard和 transform的使用。
Datasets
請注意,這里是 Datasets 而不是 Dataset,這是一個復數!我們首先查看一下他的幫助文檔
下面,我們以 torchvision.datasets.CIFAR10
為例,進行展示使用方法
CIFAR10
幫助文檔
transform 和 target_transform 處理的對象不同,一個是 Image,一個是 image 對應的類別 target
CIFAR10的使用代碼
from torchvision import datasets
from torchvision import transforms
from PIL import Image
import cv2
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
root_dir = "data_cifar10"
data_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
data_test = datasets.CIFAR10(root=root_dir, train=False, transform=transforms.ToTensor(), download=True)
targets = []
img, target = data_test[0]
trans_to_PIL = transforms.ToPILImage()
trans_to_PIL(img).show()
for i in range(10):
targets.append(data_test[i][1])
print(targets)
debug 過程中, 查看 data_train 的屬性
建議
-
download=True
當我們使用 dataset的 時候,最好設置download=True
,通過設置參數為True
,使得我們給定的 root 文件夾不存在數值的時候,自動下載文件,倘若存在的時候,並不會下載文件;也就是說,download=True,代碼不容易出現一下奇奇怪怪的錯誤 -
資源下載
當我們下載文件比較慢的時候,可以通過自己手動下載,或者是借助一些下載工具完成下載,然后將其復制到我們的目標目錄下即可。這主要是考慮到使用一些下載器其實能夠起到一下加速下載的效果。
文件的下載鏈接也比較容易找到,如下圖所示:
或者是進入我們的pycharm 幫助文檔進行查找:
2. Dataloader
PyTorch
中數據讀取的一個重要接口是torch.utils.data.DataLoader
,主要用來將自定義的數據讀取接口的輸出或者PyTorch
已有的數據讀取接口的輸入按照batch size封裝成Tensor,后續只需要再包裝成Variable即可作為模型的輸入。
下面,我們打開官方文檔來一探究竟。
Dataloader 官方文檔
Dataloader 的使用
import cv2
from PIL import Image
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
root_dir = "./data_cifar10"
dataset_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
dataset_test = datasets.CIFAR10(root=root_dir, train=False, transform=transforms.ToTensor(), download=True)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True, drop_last=True)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=64, shuffle=True, drop_last=True)
log_dir = "logs"
writer = SummaryWriter(log_dir=log_dir)
for epoch in range(2):
step = 0
for data_imgs, data_targets in dataloader_test:
writer.add_images(f"epoch{epoch}", data_imgs, step) #
step += 1
writer.close()
代碼中主要容易的是這句話
for data_imgs, data_targets in dataloader_test:
writer.add_images(f"epoch{epoch}", data_imgs, step) #
step += 1
第一個, 我們接受的是兩個參數,一個是 data_imgs, 一個是 data_targets
writer.addImages,注意是 Images,因為添加的並不是一張圖片
附錄
幾個經常導入的 包
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from PIL import Image
import os