什么是pytorch(4.數據集加載和處理)(翻譯)


數據集加載和處理

這里主要涉及兩個包:torchvision.datasets 和torch.utils.data.DatasetDataLoader

torchvision.datasets是一些包裝好的數據集

里邊所有可用的dataset都是 torch.utils.data.Dataset 的子類,這些子類都要有 __getitem__ __len__ 方法是實現。

這樣, 定義的數據集才能夠被 torch.utils.data.DataLoader ,DataLoader能夠使用torch.multiprocessing並行加載許多樣本

例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/') data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=args.nThreads) 

當我們需要使用我們的數據集的時候,就需要進行包裝成DataLoader能夠識別的Dataset這樣就能把我們從無窮的數據預處理中解脫出來。
創建數據集
首先導入,創建一個子類:
from torch.utils.data import Dataset
import torch
class MyDateset(Dataset):
    def __init__(self,num=10000,transform=None):  #這里就可以寫你的參數了,比如文件夾什么的。
        self.len=num
        self.transform=transform
    def __len__(self):
        return self.len
    def __getitem__(self,idx):
        data=torch.rand(3,3,5)  #這里就是你的數據圖像的話就是C*M*N的tensor,這里創建了一個3*3*5的張量
        label=torch.LongTensor([1])   #label也是需要一個張量
        if self.transform:    #這里就是數據預處理的部分 、
            data=self.transform(data)  #處理完必須要返回torch.Tensor類型
        return data,label

下面我們測試一下:
md=MyDateset()
print(md[0])
print(len(md))
輸出:
(tensor([[[0.2753, 0.8114, 0.2916, 0.9600, 0.5057], [0.8595, 0.1195, 0.8065, 0.6393, 0.6213], 
[0.0997, 0.8590, 0.2469, 0.2158, 0.5296]], [[0.4764, 0.0561, 0.5866, 0.6129, 0.1882],
[0.4666, 0.9362, 0.5397, 0.3065, 0.4307], [0.4700, 0.6202, 0.3649, 0.6357, 0.5181]],
[[0.9794, 0.8127, 0.9842, 0.8821, 0.2447], [0.2320, 0.6406, 0.5683, 0.5637, 0.2734],
[0.2131, 0.5853, 0.5633, 0.9069, 0.9250]]]), tensor([1]))
10000
輸出:這樣我們就自定義了一個數據集Dataset,這樣我們需要使用已有的數據集的時候就可以知道torchvision.dataset下許多數據集的構成了。
 
預處理數據

返回來再看上邊定義數據集里有個參數transform,從定義getitem函數里看到,transform其實是一個函數。
torchvision.transforms里就包括了好多的操作。當然它主要處理的是圖像,就是C*H*W類型的舉證了。
可以直接這樣使用:
from torchvision import transforms

md=MyDateset(transform=transforms.Normalize((0,0,0),(0.1,0.2,0.3)))
print(md[0])
(tensor([[[2.5435, 9.1073, 4.1653, 9.4720, 0.7595],
         [0.4840, 7.2377, 3.1578, 4.5391, 2.7440],
         [4.6951, 4.7698, 1.1308, 0.5321, 3.5101]],

        [[2.6714, 4.5143, 0.0582, 0.2880, 0.2565],
         [2.2951, 0.0680, 0.3542, 4.7372, 2.0162],
         [1.4065, 2.5195, 0.8911, 4.8432, 3.1045]],

        [[2.7726, 2.5199, 0.8066, 0.7089, 2.0651],
         [1.8641, 1.6599, 0.5546, 2.8716, 2.0964],
         [2.5320, 1.5349, 1.8792, 0.0933, 3.2289]]]), tensor([1]))
更多的變換參見:https://pytorch.org/docs/master/torchvision/transforms.html

當然我們也可以自定義一個函數傳入:
def add1(x):
    return x+1
md=MyDateset(transform=add1)
print(md[0])
輸出:
(tensor([[[1.9552, 1.1294, 1.9435, 1.6476, 1.2726],
         [1.1544, 1.7726, 1.1975, 1.9914, 1.2694],
當然也可以組合起來個transform形成一個一個處理級聯:
tc=transforms.Compose([transforms.Normalize((0,0,0),(0.1,0.2,0.3)),add1])
md=MyDateset(transform=tc)
print(md[0])

輸出:
(tensor([[[ 1.9232,  6.4972,  7.9916,  4.3426, 10.9737],
         [ 5.4062,  2.6264,  6.8474,  4.7810,  3.3232],
         [ 8.6633,  4.1399,  2.3371,  5.5058,  3.9724]],
等等。


用Dataloader加載數據集

在訓練網絡,測試網絡時我們就需要使用剛才定義好的數據集了。

from torch.utils.data import Dataset, DataLoader
md=MyDateset()
print(md[1])
dl=DataLoader(md, batch_size=4,  shuffle=False,  num_workers=4)
print(len(dl.dataset))

這樣dl就可以在程序里循環生成批樣本,提供訓練,測試了。




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM