Pytorch框架學習---(2)輸入數據操作


本節講述Data如何利用Pytorch提供的DataLoader進行讀取,以及Transforms的圖片處理方式。 【文中思維導圖采用MindMaster軟件】

注意:籠統總結Transforms,目前僅具體介紹裁剪、翻轉、標准化,后續隨着代碼需要,再逐步更新。

一. 數據讀取(DataLoader和Dataset)

1.DataLoader

  我們采用Pytorch提供的DataLoader進行數據Batch封裝,其中需要定義dataset類。

自定義的dataset類需要復寫def getitem(self, index):函數!!!

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=Batch_Size,
                          shuffle=True)
for epoch in range(Max_Epoch):
    for i, (inputs, labels) in enumerate(train_loader):  # 每次調用一個batch,后台索引
# 也可以采用next(iter(train_loader)), 讀取一個批次

  在網絡運行時,我們采用enumerate函數,進行迭代,這里會:

  • 進入DataLoader數據裝載器;

  • 判斷參數,是否采用多進程處理;

  • 調用Sampler函數,根據輸入數據個數(由Dataset類中def len()函數得到),隨機獲取index索引值;

  • 進入我們定義的Dataset類,調用def getitem(),根據index獲取數據,返回;

  • 調用collate_fn()函數整理數據,最終得到Batch。

2.代碼(如何將電腦中的數據送入網絡?)

注意:這里數據集已經分類好,文件夾已經各自建立,不包含划分數據的函數!!

import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

category = {"0": 0, "1": 1, "1_enhanced": 2, "1_enhanced_2": 3, "0_enhanced_1":4}  # 定義標簽,"文件夾名":標簽

class my_dataset(Dataset):
    '''根據自己的數據,進行讀取,Dataset類創建Pytorch數據集類型'''
    '''
    Args:
        data_dir: 數據地址(訓練集、驗證集、測試集)
        transform: torchvision.transforms(各種變換、以及Totensor)      
    Return:
        read_data  根據dataloader的索引獲取數據
        len(self.data_info)  數據個數
    '''

    def __init__(self, data_dir, transform=None):
        self.transforms = transform
        self.data_info = self.get_dataset_info(data_dir)  # 獲取所有數據路徑和對應的標簽,方便dataloader 用index批量處理

    def __getitem__(self, index):  # 當dataloader sampler得到index,根據該index索引dataset中數據
        path_data, label = self.data_info[index]
        read_data = Image.open(path_data).convert("RGB")  # PIL-->RGB(0-256)

        if self.transforms is not None:
            read_data = self.transforms(read_data)

        return read_data, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod  # 定義該函數為靜態類型,不用實例化類也可調用
    def get_dataset_info(data_dir):
        data_info = list()  # 最終包含所有圖片、標簽(每一行)
        for root, dirs, files in os.walk(data_dir):  # 獲取當前文件夾的父目錄、當前文件夾下所有文件名、所有內部文件
            for sub_dir in dirs:  # 遍歷所有類別
                each_cate = os.listdir(os.path.join(root, sub_dir))  # os.listdir() 方法用於返回指定的文件夾包含的文件或文件夾的名字的列表。

                for i in range(len(each_cate)):  # 遍歷每一個類別下的圖片數據,將標簽一同嵌入
                    each_data_name = each_cate[i]
                    each_data_path = os.path.join(root, sub_dir, each_data_name)
                    each_label = category[sub_dir]

                    data_info.append((each_data_path, int(each_label)))

        return data_info

二.數據預處理(torchvision.transforms)

1.torchvision

2.transforms.Compose([......])組合

  計算機將按照Compose中定義的transforms操作,依次進行數據處理。

train_transforms = transforms.Compose([
    transforms.Resize((75, 75)),
    transforms.ToTensor(),  # (H x W x C) [0, 255] to a torch.FloatTensor (C x H x W) [0.0, 1.0]
    transforms.Normalize(mean=norm_mean,std=norm_std)  # 逐通道歸一化,注意通道數
])

3.各種transforms處理方式

  本節目前僅介紹:標准化Normalize、圖像裁剪Crop、旋轉翻轉。

(1)數據標准化

transforms.Normalize(mean, std, inplace=False)  #逐通道對圖像進行標准化,mean:(M1,...,Mn) and std: (S1,..,Sn) for n channels
# input[channel] = (input[channel] - mean[channel]) / std[channel]

(2)裁剪

a)從中心進行裁剪

transforms.CenterCrop(size=32)  # 由圖像中心進行裁剪,size=32*32

b)隨機裁剪

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
# 先填充再隨機裁剪
# padding:設置填充大小,數值a --> 上下左右填充a個像素,(a,b)--> 左右a上下b, (a,b,c,d) --> 左a上b右c下d
# padding_mode:填充模式:
      # constant:像素值由fill參數設定;
      # edge:由圖像邊緣像素決定;
      # reflect:鏡像填充,最后一個像素不鏡像;
      # symmetric:鏡像填充,最后一個像素鏡像。

c)隨機面積、隨機長寬比裁剪圖片

transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
# 先選擇scale,再ratio,再判斷size,是否需要interpolation進行resized
# scale=(0.08, 1.0):隨機裁剪面積比例,范圍內隨機選
# ratio=(3. / 4., 4. / 3.):隨機長寬比
# interpolation:插值方法

d)上下左右中心隨機裁剪5張圖片

transforms.FiveCrop(size)  # 從上下左右中心各裁剪出五張圖片
transforms.TenCrop(size, vertical_flip=False)  # 先進行FiveCrop(),再對五張圖片進行水平/垂直鏡像,獲得10張圖片

注意:這里返回的是tuple()類型,需要按行拼接起來,送入下游transforms處理。

>>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops

有問題:當采用數據增強時,一方面采用TenCrop形式,另一方面采用其他數據變換,一同送入Dataloader時會產生錯誤,因為維度不一致,其他數據變換在dataset中為三維(channel,H,W),而TenCrop卻是四維(ncrops,channel,H,W),於是當迭代獲取Batch時會由於維度不匹配程序報錯。
解決方法:【等后續找到再來寫,手動狗頭微笑】

(3)翻轉、旋轉

transforms.RandomHorizontalFlip(p=0.5)  # 依概率進行水平(左右)翻轉
transforms.RandomVerticalFlip(p=0.5)  # 依概率進行垂直(上下)翻轉
transforms.RandomRotation(degrees, resample=False, expand=False, center=None)  # 隨機旋轉圖片
      # degrees:旋轉角度,若為a,則在(-a,a)之間二選一,若為(a, b),則二選一
      # expand:是否擴大圖片(因為旋轉過后可能會丟失圖片某一塊),僅針對中心點旋轉
      # center:旋轉點設置,默認中心點

(4)對各種變換的組合--》選擇操作(如RandomChoice)

transforms.RandomChoice([transforms1, transforms2, ......])  # 隨機挑選一個
transforms.RandomApply([transforms1, transforms2, ......], p=0.5)  # 依概率執行整個一組(要么執行,要么不執行)
transforms.RandomOrder([transforms1, transforms2, ......])  # 對一組操作進行打亂順序,再去執行這一組

4.自定義Transforms方法

class YourTransforms(object):
    def __init__(self,Arg1,Arg2):
        '''傳參數'''
    def __call__(self, x):
        '''定義該Transforms方法'''
        return x


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM