利用torch.utils.data.Dataset自定義數據加載類


import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

import torchvision.transforms as T

 

transforms = T.Compose([

  T.Resize(224),

  T.CenterCrop(224),

  T.ToTensor(),

  T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

])

 

# 繼承Dataset類要重寫__getitem__()和__len__()
class CatDog(data.Dataset):
  def __init__(self, root, transforms=None):

    # 臨時變量不用加self
    imgs = os.listdir(root)
    self.imgs = [os.path.join(root, img) for img in imgs]

    self.transforms = transforms

  def __getitem__(self, index):
    label = 1 if dog else 0

    data = Image.open(self.imgs[index])
    if self.transform:

      data = self.transform(data)
    return data, label

  def __len__(self):
    return len(self.imgs)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM