本節講述Data如何利用Pytorch提供的DataLoader進行讀取,以及Transforms的圖片處理方式。 【文中思維導圖采用MindMaster軟件】 注意:籠統總結Transforms,目前僅具體介紹裁剪、翻轉、標准化,后續隨着代碼需要,再逐步更新。 |
一. 數據讀取(DataLoader和Dataset)
1.DataLoader
我們采用Pytorch提供的DataLoader進行數據Batch封裝,其中需要定義dataset類。
自定義的dataset類需要復寫def getitem(self, index):函數!!!
train_loader = DataLoader(dataset=train_dataset,
batch_size=Batch_Size,
shuffle=True)
for epoch in range(Max_Epoch):
for i, (inputs, labels) in enumerate(train_loader): # 每次調用一個batch,后台索引
# 也可以采用next(iter(train_loader)), 讀取一個批次
在網絡運行時,我們采用enumerate函數,進行迭代,這里會:
-
進入DataLoader數據裝載器;
-
判斷參數,是否采用多進程處理;
-
調用Sampler函數,根據輸入數據個數(由Dataset類中def len()函數得到),隨機獲取index索引值;
-
進入我們定義的Dataset類,調用def getitem(),根據index獲取數據,返回;
-
調用collate_fn()函數整理數據,最終得到Batch。
2.代碼(如何將電腦中的數據送入網絡?)
注意:這里數據集已經分類好,文件夾已經各自建立,不包含划分數據的函數!!
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
category = {"0": 0, "1": 1, "1_enhanced": 2, "1_enhanced_2": 3, "0_enhanced_1":4} # 定義標簽,"文件夾名":標簽
class my_dataset(Dataset):
'''根據自己的數據,進行讀取,Dataset類創建Pytorch數據集類型'''
'''
Args:
data_dir: 數據地址(訓練集、驗證集、測試集)
transform: torchvision.transforms(各種變換、以及Totensor)
Return:
read_data 根據dataloader的索引獲取數據
len(self.data_info) 數據個數
'''
def __init__(self, data_dir, transform=None):
self.transforms = transform
self.data_info = self.get_dataset_info(data_dir) # 獲取所有數據路徑和對應的標簽,方便dataloader 用index批量處理
def __getitem__(self, index): # 當dataloader sampler得到index,根據該index索引dataset中數據
path_data, label = self.data_info[index]
read_data = Image.open(path_data).convert("RGB") # PIL-->RGB(0-256)
if self.transforms is not None:
read_data = self.transforms(read_data)
return read_data, label
def __len__(self):
return len(self.data_info)
@staticmethod # 定義該函數為靜態類型,不用實例化類也可調用
def get_dataset_info(data_dir):
data_info = list() # 最終包含所有圖片、標簽(每一行)
for root, dirs, files in os.walk(data_dir): # 獲取當前文件夾的父目錄、當前文件夾下所有文件名、所有內部文件
for sub_dir in dirs: # 遍歷所有類別
each_cate = os.listdir(os.path.join(root, sub_dir)) # os.listdir() 方法用於返回指定的文件夾包含的文件或文件夾的名字的列表。
for i in range(len(each_cate)): # 遍歷每一個類別下的圖片數據,將標簽一同嵌入
each_data_name = each_cate[i]
each_data_path = os.path.join(root, sub_dir, each_data_name)
each_label = category[sub_dir]
data_info.append((each_data_path, int(each_label)))
return data_info
二.數據預處理(torchvision.transforms)
1.torchvision
2.transforms.Compose([......])組合
計算機將按照Compose中定義的transforms操作,依次進行數據處理。
train_transforms = transforms.Compose([
transforms.Resize((75, 75)),
transforms.ToTensor(), # (H x W x C) [0, 255] to a torch.FloatTensor (C x H x W) [0.0, 1.0]
transforms.Normalize(mean=norm_mean,std=norm_std) # 逐通道歸一化,注意通道數
])
3.各種transforms處理方式
本節目前僅介紹:標准化Normalize、圖像裁剪Crop、旋轉翻轉。
(1)數據標准化
transforms.Normalize(mean, std, inplace=False) #逐通道對圖像進行標准化,mean:(M1,...,Mn) and std: (S1,..,Sn) for n channels
# input[channel] = (input[channel] - mean[channel]) / std[channel]
(2)裁剪
a)從中心進行裁剪
transforms.CenterCrop(size=32) # 由圖像中心進行裁剪,size=32*32
b)隨機裁剪
transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
# 先填充再隨機裁剪
# padding:設置填充大小,數值a --> 上下左右填充a個像素,(a,b)--> 左右a上下b, (a,b,c,d) --> 左a上b右c下d
# padding_mode:填充模式:
# constant:像素值由fill參數設定;
# edge:由圖像邊緣像素決定;
# reflect:鏡像填充,最后一個像素不鏡像;
# symmetric:鏡像填充,最后一個像素鏡像。
c)隨機面積、隨機長寬比裁剪圖片
transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
# 先選擇scale,再ratio,再判斷size,是否需要interpolation進行resized
# scale=(0.08, 1.0):隨機裁剪面積比例,范圍內隨機選
# ratio=(3. / 4., 4. / 3.):隨機長寬比
# interpolation:插值方法
d)上下左右中心隨機裁剪5張圖片
transforms.FiveCrop(size) # 從上下左右中心各裁剪出五張圖片
transforms.TenCrop(size, vertical_flip=False) # 先進行FiveCrop(),再對五張圖片進行水平/垂直鏡像,獲得10張圖片
注意:這里返回的是tuple()類型,需要按行拼接起來,送入下游transforms處理。
>>> transform = Compose([
>>> TenCrop(size), # this is a list of PIL Images
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
>>> input, target = batch # input is a 5d tensor, target is 2d
>>> bs, ncrops, c, h, w = input.size()
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
有問題:當采用數據增強時,一方面采用TenCrop形式,另一方面采用其他數據變換,一同送入Dataloader時會產生錯誤,因為維度不一致,其他數據變換在dataset中為三維(channel,H,W),而TenCrop卻是四維(ncrops,channel,H,W),於是當迭代獲取Batch時會由於維度不匹配程序報錯。
解決方法:【等后續找到再來寫,手動狗頭微笑】
(3)翻轉、旋轉
transforms.RandomHorizontalFlip(p=0.5) # 依概率進行水平(左右)翻轉
transforms.RandomVerticalFlip(p=0.5) # 依概率進行垂直(上下)翻轉
transforms.RandomRotation(degrees, resample=False, expand=False, center=None) # 隨機旋轉圖片
# degrees:旋轉角度,若為a,則在(-a,a)之間二選一,若為(a, b),則二選一
# expand:是否擴大圖片(因為旋轉過后可能會丟失圖片某一塊),僅針對中心點旋轉
# center:旋轉點設置,默認中心點
(4)對各種變換的組合--》選擇操作(如RandomChoice)
transforms.RandomChoice([transforms1, transforms2, ......]) # 隨機挑選一個
transforms.RandomApply([transforms1, transforms2, ......], p=0.5) # 依概率執行整個一組(要么執行,要么不執行)
transforms.RandomOrder([transforms1, transforms2, ......]) # 對一組操作進行打亂順序,再去執行這一組
4.自定義Transforms方法
class YourTransforms(object):
def __init__(self,Arg1,Arg2):
'''傳參數'''
def __call__(self, x):
'''定義該Transforms方法'''
return x