Pytorch系列:(二)數據加載


DataLoader

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,
batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,
drop_last=False,timeout=0,work_init_fn=None)

常用參數說明:

  • dataset: Dataset類 ( 詳見下文數據集構建 ),可以自定義數據集或者讀取pytorch自帶數據集

  • batch_size: 每個batch加載多少個樣本, 默認1

  • shuffle: 是否順序讀取,True表示隨機打亂,默認False

  • sampler:定義從數據集中提取樣本的策略。如果指定,則忽略shuffle參數。

  • batch_sampler: 定義一個按照batch_size大小返回索引的采樣器。采樣器詳見下文Batch_Sampler

  • num_workers: 數據讀取進程數量, 默認0

  • collate_fn: 自定義一個函數,接收一個batch的數據,進行自定義處理,然后返回處理后這個batch的數據。例如改變數據類型:

def my_collate_fn(batch_data):
    x_batch = []
    y_batch = []
    for x,y in batch_data:
        x_batch.append(x.float())
        y_batch.append(y.int())
    return x_batch,y_batch

  • pin_memory:設置pin_memory=True,則意味着生成的Tensor數據最開始是屬於內存中的鎖頁內存,這樣將內存的Tensor轉義到GPU的顯存就會更快一些。默認為False.

    主機中的內存,有兩種,一種是鎖頁,一種是不鎖頁。鎖頁內存存放的內容在任何情況下都不會與主機的虛擬內存 (硬盤)進行交換,而不鎖頁內存在主機內存不足時,數據會存放在虛擬內存中。注意顯卡中的顯存全部都是鎖業內存。如果計算機內存充足的話,設置為True可以加快數據交換順序。

  • drop_last:默認False, 最后剩余數據量不夠batch_size時候,是否丟棄。

  • timeout: 設置數據讀取的時間限制,超過限制時間還未完成數據讀取則報錯。數值必須大於等於0

數據集構建

自定義數據集

自定義數據集,需要繼承torch.utils.data.Dataset,然后在__getitem__()中,接受一個索引,返回一個樣本, 基本流程,首先在__init__()加載數據以及做一些處理,在__getitem__()中返回單個數據樣本,在__len__() 中,返回樣本數量

import torch
import torch.utils.data.dataset as Data 

class MyDataset(Data.Dataset):

    def __init__(self):
            self.x = torch.randn((10,20))
            self.y = torch.tensor([1 if i>5 else 0 for i in range(10)],
            dtype=torch.long)
            
    def __getitem__(self,idx):
        return self.x[idx],self.y[idx]
        
    def __len__(self):
        return self.x.__len__()
       

torchvision數據集

pytorch自帶torchvision庫可以幫助我們方便快捷的讀取和加載數據

import torch
from torchvision import datasets, transforms

# 定義一個預處理方法
transform = transforms.Compose([transforms.ToTensor()])
# 加載一個自帶數據集
trainset = datasets.MNIST('/pytorch/MNIST_data/', download=True, train=True, 
transform=transform)

TensorDataset

注意這里的tensor必須是一維度的數據。

import torch.utils.data as Data
x = torch.tensor([1,2,3,4,5])
y = torch.tensor([0,0,0,1,1])
dataset = Data.TensorDataset(x,y) 

從文件夾中加載數據集

如果想要加載自己的數據集可以這樣,用貓狗數據集舉例,根目錄下 ( "data/train" ),分別放置兩個文件夾,dog和cat,這樣使用ImageFolder函數就可以自動的將貓狗照片自動的按照文件夾定義為貓狗兩個標簽

import torch
from torchvision import datasets, transforms

data_dir = "data/train"
transform = transforms.Compose([transforms.Resize(255),transforms.ToTensor()])

dataset = datasets.ImageFolder(data_dir, transform=transform)


數據集操作

數據拼接

連接不同的數據集以構成更大的新數據集。

class torch.utils.data.ConcatDataset( [datasets, ... ] )

newDataset = torch.utils.data.ConcatDataset([dataset1,dataset2])

數據切分

方法一: class torch.utils.data.Subset(dataset, indices)

取指定一個索引序列對應的子數據集。

from torch.utils.data import Subset

train_set = Subset(dataset,[i for i in range(1,100)]
test_set = Subset(test0_ds,[i for i in range(100,150)]

方法二:torch.utils.data.random_split(dataset, lengths)

from torch.utils.data import random_split
train_set, test_set =  random_split(dataset,[100,50])

采樣器

所有采樣器都在 torch.utils.data 中,采樣器會根據該有的策略返回一組索引,在DataLoader中設定了采樣器之后,會根據索引讀取相應的樣本, 不同采樣器生成的索引不一樣,從而實現不同的采樣目的。

Sampler

所有采樣器的基類,自定義采樣器的時候需要實現 __iter__() 函數

class Sampler(object):
    """ 
    Base class for all Samplers.
    """
    
    def __init__(self, data_source):
        pass
    
    def __iter__(self):
        raise NotImplementedError

RandomSampler

RandomSampler,當DataLoader的shuffle參數為True時,系統會自動調用這個采樣器,實現打亂數據。默認的是采用SequentialSampler,它會按順序一個一個進行采樣。

SequentialSampler

按順序采樣,當DataLoader的shuffle參數為False時,使用的就是SequentialSampler。

SubsetRandomSampler

輸入一個列表,按照這個列表采樣。也可以通過這個采樣器來分割數據集。

BatchSampler

參數:sampler, batch_size, drop_last

每此返回batch_size數量的采樣索引,通過設置sampler參數來使用不同的采樣方法。

WeightedRandomSampler

參數:weights, num_samples, replacement

它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重采樣。通過weights 設定樣本權重,權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。num_samples 為返回索引的數量,replacement表示是否是放回抽樣,如果為True,表示可以重復采樣,默認為True

自定義采樣器

集成Sampler類,然后實現__iter__() 方法,比如,下面實現一個SequentialSampler類

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """
   
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM