Pytorch數據讀取機制(DataLoader)與圖像預處理模塊(transforms)


Pytorch數據讀取機制(DataLoader)與圖像預處理模塊(transforms)

1.DataLoader

torch.utils.data.DataLoader():構建可迭代的數據裝載器, 訓練的時候,每一個for循環,每一次iteration,就是從DataLoader中獲取一個batch_size大小的數據的。

Dataloader()參數:

  • dataset: Dataset類,決定數據從哪讀取(數據路徑)以及如何讀取(做哪些預處理)
  • batchsize: 批大小
  • num_works: 是否采用多進程讀取機制
  • shuffle: 每一個epoch是否亂序
  • drop_last: 當樣本數不能被batchsize整除時,是否舍棄最后一批數據。

2. Dataset

torch.utils.data.Dataset():Dataset抽象類, 所有自定義的Dataset都需要繼承它,並且必須復寫__getitem__()這個類方法。

__getitem__方法的是Dataset的核心,作用是接收一個索引, 返回一個樣本, 看上面的函數,參數里面接收index,然后我們需要編寫究竟如何根據這個索引去讀取我們的數據部分。

2.1 ImageFolder

torchvision已經預先實現了常用的Dataset, 其他預先實現的有: torchvision.datasets.CIFAR10, 可以讀取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等數據集。

ImageFolder假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片,文件夾名為類名,其構造函數如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

參數:

  • root: 圖片路徑
  • transform: 對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回對象
  • target_transform:對label的轉換
  • loader:給定路徑后如何讀取圖片,默認讀取為RGB格式的PIL Image對象

示例:

文件夾格式:

image-20210720154912213

train_path = r'datasets/myDataSet/train'

預處理格式:

train_transform = transforms.Compose([
    transforms.Resize((40,40)),
    transforms.RandomCrop(40,padding=4),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225],)
])

dataset:

trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元組類型,第30號圖片的(像素信息,label)

Data.DataLoader:

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一個batch圖片對應的label

2.2

class myDataset(Data.Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.data_info = self.get_img_info(path)
        self.label = []
        for i in range(len(self.data_info)):
            self.label.append(list(self.data_info[i])[1])

    def __getitem__(self, idx):
        path_img = self.data_info[idx][0]
        label = self.label[idx]
        img = Image.open(path_img).convert('RGB')  # 0~255
        if self.transform is not None:
            img = self.transform(img)  # 在這里做transform,轉為tensor等等
        return img, label, idx

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍歷類別
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍歷圖片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = int(sub_dir)
                    data_info.append((path_img, int(label)))
        return data_info

trainset = myDataset(train_path, train_transform)

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一個batch的圖片對應的label
    print(index) #  一個batch的圖片在數據集中對應的index

s


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM