在深度學習中,數據的處理對於神經網絡的訓練來說十分重要,良好的數據(包括圖像、文本、語音等)處理不僅可以加速模型的訓練,同時也直接關系到模型的效果。本文以處理圖像數據為例,記錄一些使用PyTorch進行圖像預處理和數據加載的方法。
一、數據的加載
在PyTorch中,數據加載需要自定義數據集類,並用此類來實例化數據對象,實現自定義的數據集需要繼承torch.utils.data包中的Dataset類。
在繼承Dataset實現自己的類時,需要實現以下兩個Python魔法方法:
- __getitem__(index): 返回一個樣本數據,當使用obj[index]時實際就是在調用obj.__getitem__(index)
- __len__():返回樣本的數量,當使用len(obj)時實際就是在調用obj.__len__()
例如,以貓狗大戰的二分類數據集為例,其加載過程如下:
import os
import torch as t
from torch.utils import data
from PIL import Image
import numpy as np
class dogCat(data.Dataset):
def __init__(self,root): # root為數據存放目錄
imgs = os.listdir(root) #列出當前路徑下所有的文件
self.imgs = [os.path.join(root,img) for img in imgs] # 所有圖片的路徑
#print(self.imgs)
"""返回一個樣本數據"""
def __getitem__(self, item):
img_path = self.imgs[item] # 第item張圖片的路徑
#dog 1 cat 0
label = 1 if 'dog' in img_path.split('\\')[-1] else 0 # 獲取標簽信息
#print(label)
pil_img = Image.open(img_path) #讀入圖片
print(type(pil_img))
array = np.asarray(pil_img) # 轉為numpy.array類型
data = t.from_numpy(array) # 轉為tensor類型
return data,label #返回圖片對應的tensor及其標簽
"""樣本的數量"""
def __len__(self):
return len(self.imgs)
if __name__ == '__main__':
dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train') #數據集對象
data,label = dogcat[0] # 返回第0張圖片的信息
print(data.size())
print(label)
print(len(dogcat))
二、計算機視覺工具包:torchvision
對於圖像數據來說,以上的數據加載時不完善的,因為只是將圖片讀入,而沒有進行相關的處理,如每張圖片的大小和形狀,樣本的數值歸一化等等。
為了解決這一問題,PyTorch開發了一個視覺工具包torchvision,這個包獨立於torch,需要通過pip install torchvision
來單獨安裝。
torchvision有三個部分組成:
- models:提供各種經典的網絡結構和預訓練好的模型,如AlexNet、VGG、ResNet、Inception等。
from torchvision import models
from torch import nn
resnet34 = models.resnet34(pretrained=True,num_classes=1000) # 加載預訓練模型
resnet34.fc=nn.Linear(512,10) # 修改全連接層為10分類
- datasets:提供了常用的數據集,如MNIST、CIFAR10/100、ImageNet、COCO等。
from torchvision import datasets
dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)
除了常用數據集外,需要特別注意的是ImageFolder,ImageFolder假設所有的文件按文件夾存放,每個文件夾下面存儲同一類的圖片,文件夾的名字為這一類別的名字。這是我們經常用到的一種數據組織形式。
# 使用方法:
ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
# 參數:文件夾路徑,對圖像做什么樣的轉換,對標簽做什么樣的轉換,如何加載圖片
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data\\')
print(dataset.class_to_idx) # class_to_idx ,label和id的對應關系,從0開始
print(dataset.imgs) # 數據和標簽對應
- transforms: 提供常用的數據預處理操作,主要是對Tensor和PIL Image對象的處理操作。
對PIL Image的操作:Resize、CenterCrop、RandomCrop、RandomsizedCrop、Pad、ToTensor等。
對Tensor的操作:Normalize、ToPILImage等。
如果要進行多個操作,可以通過transforms.Compose([])將操作拼接起來。但是需要注意的是需要首先構建轉換操作,然后再執行轉換操作。
import os
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])]) # 構建轉換操作
class dogCat(data.Dataset):
def __init__(self,root,transforms):
imgs = os.listdir(root)
#print(imgs)
self.imgs = [os.path.join(root,img) for img in imgs]
#print(self.imgs)
self.transforms = transforms
def __getitem__(self, item):
img_path = self.imgs[item]
#dog 1 cat 0
label = 1 if 'dog' in img_path.split('\\')[-1] else 0
#print(label)
pil_img = Image.open(img_path)
if self.transforms:
pil_img = self.transforms(pil_img) #執行准換操作
return pil_img,label,item
def __len__(self):
return len(self.imgs)
三、使用DataLoader進行數據再處理
通過上述描述,我們通過自定義數據集類,使用視覺工具包進行圖像的轉換等操作,最終得到的是一個dataset的數據集對象,使用此對象可以一次返回一個樣本。
但是,我們應該清楚:訓練神經網絡時,一般采用的是小批量的梯度下降,因此我們是對一批數據進行處理,也就是一個batch,同時,數據還需要進行打亂(shuffle)和並行加速等。PyTorch提供了DataLoader來實現這些功能。
DataLoader定義如下:
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)
參數含義如下:
- dataset:加載的數據集
- batch_zize: 批大小
- shuffle: 是否將數據打亂
- sampler:樣本抽樣,常用的有隨機采樣RandomSampler,shuffle=True時自動調用隨機采樣,默認是順序采樣,還有一個常用的是:WeightedRandomSampler,按照樣本的權重進行采樣。
- num_workers: 使用的進程數,0代表不使用多進程。
- collate_fn: 拼接方式。
- pin_memory: 是否將數據保存在pin memory區。
- drop_last: 是否將多出來的不足一個batch的丟棄。
調用DataLoader得到的結果是一個可迭代的對象,可以和使用迭代器一樣使用它。
from torchvision import transforms as T
from torch.utils.data import DataLoader
transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])
if __name__ == '__main__':
dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train', transform)
data, label, index = dogcat[0]
dataloader = DataLoader(dogcat,batch_size=3,shuffle=False,num_workers=0,drop_last=False)
for batchDatas,batchLabels in dataloader:
train()
總結
本文記錄了使用PyTorch進行數據預處理的相關操作流程,重點是掌握Dataset和DataLoader兩個類的使用,另外,視覺工具包torchvision的三個模塊靈活運用,會對數據處理過程有很好的幫助。