Pytorch——Dataset類和DataLoader類


  這篇文章主要探討一下,Dataset類以及DataLoader類的使用以及注意事項。Dataset類主要是用於原始數據的讀取或者基本的數據處理(比如在NLP任務中常常需要把文字轉化為對應字典ids,這個步驟就可以放在Dataset中執行)。DataLoader,是進一步對Dataset的處理,Dataset得到的數據集你可以理解為是個"列表"(可以根據index取出某個特定位置的數據),而DataLoder就是把這個數據集(Dataset)根據你設定的batch_size划分成很多個“子數據集”,每個“子數據集”中的元素數量就是batch_size。

  DataLoader為什么要把Dataset划分成多個”子數據集“呢?因為一次性把所有的數據放進模型會導致內存溢出,而且模型的迭代會很慢。下面我們就深度解析下Dataset和DataLoader的使用方式。

一、Dataset的使用

  這里說到的Dataset其實就是,torch.utils.data.Dataset類 ,換句話說我們需要創建一個Dataset類,使用類的繼承就可以了。既然是繼承類,那么肯定會修改一些父類(torch.utils.data.Dataset類 )的方法來適應我們的真實數據和邏輯。而我們主要要重寫的就是,__init__()__len__()__getitem__(),這三個方法分別是以下作用:

  • __init__方法:進行類的初始化,一般是用來讀取原始數據。
  • __getitem__方法:根據下標對每一個數據進行進一步的處理。return:希望通過dataset[index]在數據集中取出的元素
  • __len__方法:return:數據集的數量(int)

  下面用一個例子來大致說明下Dataset該怎么構建,並且如何使用。

 

from torch.utils.data import Dataset 
import torch

def MyTokenizer(sentence):
    src_vocab = {'度':0,'上':1,'世':2,'中':3,
    '為':4,'人':5,'偉':6,'你':7,'務':8,'國':9,
    '大':10,'我':11,'是':12,'最':13,'服':14,
    '民':15,'愛':16,'界':17,'的':18}
    enc_input = [src_vocab[n] for n in sentence]
    if len(enc_input) < 12:                     ## 如果enc_input的長度小於12,則用100來補足,使得enc_input長度為12.
        enc_input = enc_input+(12-len(enc_input))*[100]
    return enc_input


class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.tokenizer = MyTokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = self.data[index]
        return torch.tensor(self.tokenizer(sentence))

data = ['我愛你中國', '中國是世界上最偉大的國度', '為人民服務','你愛我']
dataset = MyDataset(data)

 

  簡單介紹一下,函數MyTokenizer(sentence):把一個句子根據字典,轉化為一個列表[...],例如”我愛你“——>[11, 16, 7, 100,100,...100]。這里為了簡便,我是用的數據就是四句話。我在初始化(__init__)我的MyDataset類時,把數據儲存下來,並且定義了我的編碼器Mytokenizer。(這里為什么輸出列表末尾會有100呢,主要是使得每一個數據長度是一樣的,為后面進入DataLoader做准備,其實這個操作就叫做padding)

  __len__(self):這里返回了我傳入數據集的大小。而__getitem__(self,index):中index指的是數據下標,根據這個下標提取出原始數據(self.data中的一句話),並且把這句話傳入到self.tokenizer進行編碼,最后返回編碼的結果(一個列表[....]),可以看到__getitem__這個函數就是根據index來處理每一個拿出來的原始數據的,你對原始數據的所有處理都可以放在這里。我們最后一行代碼是完成了MyDataset的實例化。看一看這個實例化之后的結果。

print(len(dataset))
print(dataset[0])
3
[11, 16, 7, 3, 9]

  這里返回的tensor([11,16,7,3,9]),其實就是”我愛你中國“經過編碼之后的結果(可以對照上面的字典看看)。

  其實說白了Dataset就是一個數據處理器,把數據收集起來,並且進行對每一個index的數據進行處理,最后輸出。有人會問為啥不先處理好這些數據呢?其實是因為DataLoader只能接受torch.utils.data.Dataset類作為傳入參數,因此用其他任意的數據結構都沒辦法放到DataLoader里面,這樣就沒法自動根據batch_size拆分成”子數據集“。因此Dataset是我們必須構建的,就算是我的數據不想進一步處理,也必須寫一個以上的最簡單的MyDataset類(直接傳入啥,輸出啥的類)。

二、DataLoader的使用

  先放官方文檔:官方文檔

  剛才說到Dataset的構建是為了放進到DataLoader里,為啥非要放到這里面呢?其根本原因是DataLoader中有很多好用的設置可以讓我們更好的處理數據,比如參數shuffle,可以讓Dataset中的數據打亂重新排列再進行分批次,num_workers參數可以設定安排多少個進程來加載數據(加速)。一般情況下我們不需要重寫DataLoader類,只需要實例化就可以了。例如我們把上面創建好的Dataset實例——dataset傳入到DataLoader中構建實例。

  這里一定要注意,每個batch(子集)里的長度一定要一致,不然會報錯“RuntimeError: each element in list of batch should be of equal size”。(這也就是為什么,在建立Dataset的時候我會用100來吧不足12長度的句子填充成統一長度,因為我舉的例子中沒有超過12的句子,所以不存在切割句子,真實情況需要按你自己的數據需求,但是一定要保證出來的數據要一樣長,至於為什么一會后面說)。

from torch.utils.data import DataLoader
myDataloader = DataLoader(dataset, shuffle=True, batch_size=2)

  這個myDataloader就是DataLoader的實例,已經被分為了2個數據為一個batch,接下來我們打印一下每個batch(由於我們只有4句話,2個樣本為一個batch那么其實就只有2個batch,所以可以打印來看看)。

for batch in myDataloader:
    print(batch)
    print('===============================')
tensor([[  7,  16,  11, 100, 100, 100, 100, 100, 100, 100, 100, 100],
        [  4,   5,  15,  14,   8, 100, 100, 100, 100, 100, 100, 100]])
===============================
tensor([[  3,   9,  12,   2,  17,   1,  13,   6,  10,  18,   9,   0],
        [ 11,  16,   7,   3,   9, 100, 100, 100, 100, 100, 100, 100]])
===============================

  可以看到每個batch其實是一個tensor,維度是(2,12)。每個tensor的每一行其實就是一個dataset里的一個樣本。並且要注意每個樣本已經不是按照原本的順序排列了。

 

三、collate_fn參數的使用

  在DataLoader里,除了上面提到的shuffle參數和batch_size參數以外,還有一個非常重要的傳入參數collate_fn,這個參數傳入的是一個函數,這個函數主要是對每個batch進行處理,最終輸出一個batch的返回值,換句話說collate_fn函數的返回值,就是遍歷DataLoader的時候每個“batch”的返回值了(類似於上面例子中的二維tensor)。下面我寫一個函數,讓大家看看到底是怎么處理的。

def mycollate(item):
    sample1, sample2 = item
    return {'第一個樣本':sample1,'第二個樣本':sample2}

from torch.utils.data import DataLoader
myDataloader = DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=mycollate)

  我們現在再來打印一下myDataloader的每個元素。

for batch in myDataloader:
    print(batch)
    print('===============================')
{'第一個樣本': tensor([ 11,  16,   7,   3,   9, 100, 100, 100, 100, 100, 100, 100]), '第二個樣本': tensor([  7,  16,  11, 100, 100, 100, 100, 100, 100, 100, 100, 100])}
===============================
{'第一個樣本': tensor([ 3,  9, 12,  2, 17,  1, 13,  6, 10, 18,  9,  0]), '第二個樣本': tensor([  4,   5,  15,  14,   8, 100, 100, 100, 100, 100, 100, 100])}
===============================

  可以看到,這個時候打印myDataloader的每個元素,就變成我在mycollate()函數中的返回值了。或許會不明白,我在mycollate()函數中這個傳入的item是什么?其實這個item是個元組,元組的每個元素就是dataset的每個元素(tensor([3,9,12,....])),item的元素個數其實就是batch_size,這里的batch_size是2,所以我在mycollate()中用了兩個變量來接收(換句話說要是我把batch_size換成2以外的其他數字,就會報錯了)。

  可以看到其實我們在DataLoader的時候依然可以使用函數來處理我們的數據,換句話說我們完全可以把tokenizer函數放到mycollate()函數中。

  好了現在我們可以來解釋為什么我在第二節的時候要求每個batch的數據要一樣長了,那是因為當你不給定collate_fn這個參數的時候,會自動調用一個函數叫做default_collate(),大家可以粗略的看看這個內置函數的源碼:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

  看到倒數第三行了么?這就是為什么會報錯的原因了。所以如果可以,我建議還是自己設定mycollate()函數,因為源碼里如果你的dataset輸出的元素不是tensor類型,那么將會按照它的方式來重新組織來返回,不同類別返回的東西是不一樣的,大家可以看看源碼。

 

 

 

 

參考網站:

Pytorch的第一步:(1) Dataset類的使用 - 簡書 (jianshu.com)

Pytorch 中的數據類型 torch.utils.data.DataLoader 參數詳解_Never-Giveup的博客-CSDN博客_dataloader參數

RuntimeError: each element in list of batch should be of equal size_NLP新手村成員的博客-CSDN博客


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM