數據集加載和處理
這里主要涉及兩個包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader
torchvision.datasets是一些包裝好的數據集
里邊所有可用的dataset都是 torch.utils.data.Dataset
的子類,這些子類都要有 __getitem__
和
__len__
方法是實現。
這樣, 定義的數據集才能夠被 torch.utils.data.DataLoader
,DataLoader能夠使用torch.multiprocessing
並行加載許多樣本
例如:
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/') data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=args.nThreads)
當我們需要使用我們的數據集的時候,就需要進行包裝成DataLoader能夠識別的Dataset這樣就能把我們從無窮的數據預處理中解脫出來。
創建數據集
首先導入,創建一個子類:
from torch.utils.data import Dataset
import torch
class MyDateset(Dataset):
def __init__(self,num=10000,transform=None): #這里就可以寫你的參數了,比如文件夾什么的。
self.len=num
self.transform=transform
def __len__(self):
return self.len
def __getitem__(self,idx):
data=torch.rand(3,3,5) #這里就是你的數據圖像的話就是C*M*N的tensor,這里創建了一個3*3*5的張量
label=torch.LongTensor([1]) #label也是需要一個張量
if self.transform: #這里就是數據預處理的部分 、
data=self.transform(data) #處理完必須要返回torch.Tensor類型
return data,label
import torch
class MyDateset(Dataset):
def __init__(self,num=10000,transform=None): #這里就可以寫你的參數了,比如文件夾什么的。
self.len=num
self.transform=transform
def __len__(self):
return self.len
def __getitem__(self,idx):
data=torch.rand(3,3,5) #這里就是你的數據圖像的話就是C*M*N的tensor,這里創建了一個3*3*5的張量
label=torch.LongTensor([1]) #label也是需要一個張量
if self.transform: #這里就是數據預處理的部分 、
data=self.transform(data) #處理完必須要返回torch.Tensor類型
return data,label
下面我們測試一下:
md=MyDateset()
print(md[0])
print(len(md))
print(md[0])
print(len(md))
輸出:
(tensor([[[0.2753, 0.8114, 0.2916, 0.9600, 0.5057], [0.8595, 0.1195, 0.8065, 0.6393, 0.6213],
[0.0997, 0.8590, 0.2469, 0.2158, 0.5296]], [[0.4764, 0.0561, 0.5866, 0.6129, 0.1882],
[0.4666, 0.9362, 0.5397, 0.3065, 0.4307], [0.4700, 0.6202, 0.3649, 0.6357, 0.5181]],
[[0.9794, 0.8127, 0.9842, 0.8821, 0.2447], [0.2320, 0.6406, 0.5683, 0.5637, 0.2734],
[0.2131, 0.5853, 0.5633, 0.9069, 0.9250]]]), tensor([1]))
10000
輸出:這樣我們就自定義了一個數據集Dataset,這樣我們需要使用已有的數據集的時候就可以知道torchvision.dataset下許多數據集的構成了。
預處理數據
返回來再看上邊定義數據集里有個參數transform,從定義getitem函數里看到,transform其實是一個函數。
torchvision.transforms里就包括了好多的操作。當然它主要處理的是圖像,就是C*H*W類型的舉證了。
可以直接這樣使用:
from torchvision import transforms
md=MyDateset(transform=transforms.Normalize((0,0,0),(0.1,0.2,0.3)))
print(md[0])
(tensor([[[2.5435, 9.1073, 4.1653, 9.4720, 0.7595], [0.4840, 7.2377, 3.1578, 4.5391, 2.7440], [4.6951, 4.7698, 1.1308, 0.5321, 3.5101]], [[2.6714, 4.5143, 0.0582, 0.2880, 0.2565], [2.2951, 0.0680, 0.3542, 4.7372, 2.0162], [1.4065, 2.5195, 0.8911, 4.8432, 3.1045]], [[2.7726, 2.5199, 0.8066, 0.7089, 2.0651], [1.8641, 1.6599, 0.5546, 2.8716, 2.0964], [2.5320, 1.5349, 1.8792, 0.0933, 3.2289]]]), tensor([1]))
更多的變換參見:https://pytorch.org/docs/master/torchvision/transforms.html
當然我們也可以自定義一個函數傳入:
def add1(x):
return x+1
md=MyDateset(transform=add1)
print(md[0])
輸出:
(tensor([[[1.9552, 1.1294, 1.9435, 1.6476, 1.2726], [1.1544, 1.7726, 1.1975, 1.9914, 1.2694],
當然也可以組合起來個transform形成一個一個處理級聯:
tc=transforms.Compose([transforms.Normalize((0,0,0),(0.1,0.2,0.3)),add1])
md=MyDateset(transform=tc)
print(md[0])
輸出:
(tensor([[[ 1.9232, 6.4972, 7.9916, 4.3426, 10.9737], [ 5.4062, 2.6264, 6.8474, 4.7810, 3.3232], [ 8.6633, 4.1399, 2.3371, 5.5058, 3.9724]],
等等。
用Dataloader加載數據集
在訓練網絡,測試網絡時我們就需要使用剛才定義好的數據集了。
from torch.utils.data import Dataset, DataLoader
md=MyDateset()
print(md[1])
dl=DataLoader(md, batch_size=4, shuffle=False, num_workers=4)
print(len(dl.dataset))
這樣dl就可以在程序里循環生成批樣本,提供訓練,測試了。