這篇文章主要探討一下,Dataset類以及DataLoader類的使用以及注意事項。Dataset類主要是用於原始數據的讀取或者基本的數據處理(比如在NLP任務中常常需要把文字轉化為對應字典ids,這個步驟就可以放在Dataset中執行)。DataLoader,是進一步對Dataset的處理,Dataset得到的數據集你可以理解為是個"列表"(可以根據index取出某個特定位置的數據),而DataLoder就是把這個數據集(Dataset)根據你設定的batch_size划分成很多個“子數據集”,每個“子數據集”中的元素數量就是batch_size。
DataLoader為什么要把Dataset划分成多個”子數據集“呢?因為一次性把所有的數據放進模型會導致內存溢出,而且模型的迭代會很慢。下面我們就深度解析下Dataset和DataLoader的使用方式。
一、Dataset的使用
這里說到的Dataset其實就是,torch.utils.data.Dataset類 ,換句話說我們需要創建一個Dataset類,使用類的繼承就可以了。既然是繼承類,那么肯定會修改一些父類(torch.utils.data.Dataset類 )的方法來適應我們的真實數據和邏輯。而我們主要要重寫的就是,__init__(),__len__(),__getitem__(),這三個方法分別是以下作用:
__init__
方法:進行類的初始化,一般是用來讀取原始數據。__getitem__
方法:根據下標對每一個數據進行進一步的處理。return:希望通過dataset[index]在數據集中取出的元素__len__
方法:return:數據集的數量(int)
下面用一個例子來大致說明下Dataset該怎么構建,並且如何使用。
from torch.utils.data import Dataset import torch def MyTokenizer(sentence): src_vocab = {'度':0,'上':1,'世':2,'中':3, '為':4,'人':5,'偉':6,'你':7,'務':8,'國':9, '大':10,'我':11,'是':12,'最':13,'服':14, '民':15,'愛':16,'界':17,'的':18} enc_input = [src_vocab[n] for n in sentence] if len(enc_input) < 12: ## 如果enc_input的長度小於12,則用100來補足,使得enc_input長度為12. enc_input = enc_input+(12-len(enc_input))*[100] return enc_input class MyDataset(Dataset): def __init__(self, data): self.data = data self.tokenizer = MyTokenizer def __len__(self): return len(self.data) def __getitem__(self, index): sentence = self.data[index] return torch.tensor(self.tokenizer(sentence)) data = ['我愛你中國', '中國是世界上最偉大的國度', '為人民服務','你愛我'] dataset = MyDataset(data)
簡單介紹一下,函數MyTokenizer(sentence):把一個句子根據字典,轉化為一個列表[...],例如”我愛你“——>[11, 16, 7, 100,100,...100]。這里為了簡便,我是用的數據就是四句話。我在初始化(__init__)我的MyDataset類時,把數據儲存下來,並且定義了我的編碼器Mytokenizer。(這里為什么輸出列表末尾會有100呢,主要是使得每一個數據長度是一樣的,為后面進入DataLoader做准備,其實這個操作就叫做padding)
__len__(self):這里返回了我傳入數據集的大小。而__getitem__(self,index):中index指的是數據下標,根據這個下標提取出原始數據(self.data中的一句話),並且把這句話傳入到self.tokenizer進行編碼,最后返回編碼的結果(一個列表[....]),可以看到__getitem__這個函數就是根據index來處理每一個拿出來的原始數據的,你對原始數據的所有處理都可以放在這里。我們最后一行代碼是完成了MyDataset的實例化。看一看這個實例化之后的結果。
print(len(dataset)) print(dataset[0])
3
[11, 16, 7, 3, 9]
這里返回的tensor([11,16,7,3,9]),其實就是”我愛你中國“經過編碼之后的結果(可以對照上面的字典看看)。
其實說白了Dataset就是一個數據處理器,把數據收集起來,並且進行對每一個index的數據進行處理,最后輸出。有人會問為啥不先處理好這些數據呢?其實是因為DataLoader只能接受torch.utils.data.Dataset類作為傳入參數,因此用其他任意的數據結構都沒辦法放到DataLoader里面,這樣就沒法自動根據batch_size拆分成”子數據集“。因此Dataset是我們必須構建的,就算是我的數據不想進一步處理,也必須寫一個以上的最簡單的MyDataset類(直接傳入啥,輸出啥的類)。
二、DataLoader的使用
先放官方文檔:官方文檔
剛才說到Dataset的構建是為了放進到DataLoader里,為啥非要放到這里面呢?其根本原因是DataLoader中有很多好用的設置可以讓我們更好的處理數據,比如參數shuffle,可以讓Dataset中的數據打亂重新排列再進行分批次,num_workers參數可以設定安排多少個進程來加載數據(加速)。一般情況下我們不需要重寫DataLoader類,只需要實例化就可以了。例如我們把上面創建好的Dataset實例——dataset傳入到DataLoader中構建實例。
這里一定要注意,每個batch(子集)里的長度一定要一致,不然會報錯“RuntimeError: each element in list of batch should be of equal size”。(這也就是為什么,在建立Dataset的時候我會用100來吧不足12長度的句子填充成統一長度,因為我舉的例子中沒有超過12的句子,所以不存在切割句子,真實情況需要按你自己的數據需求,但是一定要保證出來的數據要一樣長,至於為什么一會后面說)。
from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle=True, batch_size=2)
這個myDataloader就是DataLoader的實例,已經被分為了2個數據為一個batch,接下來我們打印一下每個batch(由於我們只有4句話,2個樣本為一個batch那么其實就只有2個batch,所以可以打印來看看)。
for batch in myDataloader: print(batch) print('===============================')
tensor([[ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100], [ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100]]) =============================== tensor([[ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0], [ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]]) ===============================
可以看到每個batch其實是一個tensor,維度是(2,12)。每個tensor的每一行其實就是一個dataset里的一個樣本。並且要注意每個樣本已經不是按照原本的順序排列了。
三、collate_fn參數的使用
在DataLoader里,除了上面提到的shuffle參數和batch_size參數以外,還有一個非常重要的傳入參數collate_fn,這個參數傳入的是一個函數,這個函數主要是對每個batch進行處理,最終輸出一個batch的返回值,換句話說collate_fn函數的返回值,就是遍歷DataLoader的時候每個“batch”的返回值了(類似於上面例子中的二維tensor)。下面我寫一個函數,讓大家看看到底是怎么處理的。
def mycollate(item): sample1, sample2 = item return {'第一個樣本':sample1,'第二個樣本':sample2} from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=mycollate)
我們現在再來打印一下myDataloader的每個元素。
for batch in myDataloader: print(batch) print('===============================')
{'第一個樣本': tensor([ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]), '第二個樣本': tensor([ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100])} =============================== {'第一個樣本': tensor([ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0]), '第二個樣本': tensor([ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100])} ===============================
可以看到,這個時候打印myDataloader的每個元素,就變成我在mycollate()函數中的返回值了。或許會不明白,我在mycollate()函數中這個傳入的item是什么?其實這個item是個元組,元組的每個元素就是dataset的每個元素(tensor([3,9,12,....])),item的元素個數其實就是batch_size,這里的batch_size是2,所以我在mycollate()中用了兩個變量來接收(換句話說要是我把batch_size換成2以外的其他數字,就會報錯了)。
可以看到其實我們在DataLoader的時候依然可以使用函數來處理我們的數據,換句話說我們完全可以把tokenizer函數放到mycollate()函數中。
好了現在我們可以來解釋為什么我在第二節的時候要求每個batch的數據要一樣長了,那是因為當你不給定collate_fn這個參數的時候,會自動調用一個函數叫做default_collate(),大家可以粗略的看看這個內置函數的源碼:
def default_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" elem = batch[0] elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) return torch.stack(batch, 0, out=out) elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) return default_collate([torch.as_tensor(b) for b in batch]) elif elem.shape == (): # scalars return torch.as_tensor(batch) elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float64) elif isinstance(elem, int_classes): return torch.tensor(batch) elif isinstance(elem, string_classes): return batch elif isinstance(elem, container_abcs.Mapping): return {key: default_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple return elem_type(*(default_collate(samples) for samples in zip(*batch))) elif isinstance(elem, container_abcs.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError('each element in list of batch should be of equal size') transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type))
看到倒數第三行了么?這就是為什么會報錯的原因了。所以如果可以,我建議還是自己設定mycollate()函數,因為源碼里如果你的dataset輸出的元素不是tensor類型,那么將會按照它的方式來重新組織來返回,不同類別返回的東西是不一樣的,大家可以看看源碼。
參考網站:
Pytorch的第一步:(1) Dataset類的使用 - 簡書 (jianshu.com)
Pytorch 中的數據類型 torch.utils.data.DataLoader 參數詳解_Never-Giveup的博客-CSDN博客_dataloader參數
RuntimeError: each element in list of batch should be of equal size_NLP新手村成員的博客-CSDN博客