整理一下pytorch獲取的流程:
- 創建Dataset對象
- 創建DataLoader對象,裝載有dataset對象
- 循環DataLoader對象,DataLoader.__iter__返回的是DataLoaderIter對象
dataset = MyDataset() dataloader = DataLoader(dataset) num_epoches = 100 for epoch in range(num_epoches): for data in dataloader: ....
根據源碼分析:torch.utils.data
1 - Dataset:
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])
Dataset這是一個抽象類,不能實例化,需要重寫類方法,關鍵點有兩個:
-
__getitem__ 這個很重要,規定了如何讀數據,比如常用的transform
-
__len__ 這個就是返回數據集的長度,比如:return len(self.data)
2 - DataLoader:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
先看一下主要參數:
- dataset:就是 torch.utils.data.Dataset 類的實例。也就是說為了使用 DataLoader 類,需要先定義一個 torch.utils.data.Dataset 類的實例。
- batch_size:每一個批次需要加載的訓練樣本個數。
- shuffle:如果設置為 True 表示訓練樣本數據會被隨機打亂,默認值為 False。一般會設置為 True 。
- sampler:自定義從數據集中取樣本的策略,如果指定這個參數,那么 shuffle 必須為 False 。從源碼中可以看到,如果指定了該參數,同時 shuffle 設定為 True,DataLoader 的 __init__ 函數就會拋出一個異常 。
- batch_sampler:與 sampler 類似,但是一次只返回一個 batch 的 indices(索引),需要注意的是,一旦指定了這個參數,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源碼中同樣做了限制。
- num_workers:表示會使用多少個線程來加載訓練數據;默認值為 0,表示數據加載直接在主線程中進行。
- collate_fn:對每一個 batch 的數據做一些你想要的操作。一個例子,https://zhuanlan.zhihu.com/p/346332974
- pin_memory:把數據轉移到和 GPU 相關聯的 CPU 內存,加速 GPU 載入數據的速度。
- drop_last:比如你的batch_size設置為 32,而一個 epoch 只有 100 個樣本;如果設置為 True,那么訓練的時候后面的 4 個就被扔掉了。如果為 False(默認),那么會繼續正常執行,只是最后的 batch_size 會小一點。
- timeout:加載一個 batch 數據的超時時間。
- worker_init_fn:指定每個數據加載線程的入口函數。
源碼分析:
class DataLoader(object): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False): self.dataset = dataset self.batch_size = batch_size self.num_workers = num_workers self.collate_fn = collate_fn self.pin_memory = pin_memory self.drop_last = drop_last if batch_sampler is not None: if batch_size > 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler is mutually exclusive with ' 'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle: raise ValueError('sampler is mutually exclusive with shuffle') if batch_sampler is None: if sampler is None: if shuffle: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。 sampler = RandomSampler(dataset) else: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。 sampler = SequentialSampler(dataset) # Sampler 是個迭代器,一次之只返回一個 索引 # BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引 batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler def __iter__(self): return DataLoaderIter(self) def __len__(self): return len(self.batch_sampler)
可以發現__iter__返回的是DataLoaderIter
3 - DataLoaderIter
先看init初始化:
if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn
# 定義了workers相同數量個Queue並放置在index_queues這個list中, # 這些Queue與worker一一對應,用來給worker傳遞“工作內容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用於下一個工作的workre序號,主進程輪詢使用不同workers self.worker_queue_idx = 0
# 各個workre將自己所取得的數據傳遞給wokrker_result_queue,供主進程fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 記錄當前時刻分配了多少個任務(可能有處於等待狀態的任務) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 發送出去數據的編號 self.send_idx = 0 # 接受到數據的編號 self.rcvd_idx = 0 # 緩存區 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相應的進程,目標函數為_worker_loop # 參數:dataset(用於數據讀取),index_queues[i]為worker對應的index_queue # 以及用於輸出的queue # 此處主要用於數據讀取后的pin_memory操作,不影響多進程主邏輯,暫不展開 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 將父進程設置為守護進程,保證父進程結束后,worker進程也結束,必須設置在start之前 w.start() # 下面是一些系統信號處理邏輯,對這方面我還不太熟悉就不介紹了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化后生成2*num_workers數量個prefetch的數據,使dataloader提前工作,提升整體效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()
init過程有兩個函數,一個是worker_loop,另個是put_indices
a. 先看worker_loop:
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): global _use_shared_memory _use_shared_memory = True # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal # module's handlers are executed after Python returns from C low-level # handlers, likely when the same fatal signal happened again already. # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 _set_worker_signal_handlers() torch.set_num_threads(1) random.seed(seed) torch.manual_seed(seed) if init_fn is not None: init_fn(worker_id) # 父進程狀態監測 watchdog = ManagerWatchdog() # 死循環查詢是否有任務傳進來 while True: try: # 從index_queue獲取相應數據 r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL) except queue.Empty: if watchdog.is_alive(): continue else: break if r is None: break idx, batch_indices = r try: # 獲得以后for循環進行讀取數據讀取,此處和單進程的工作原理是一樣的 # 因此時間花費和batchsize數量呈線性關系 samples = collate_fn([dataset[i] for i in batch_indices]) # 經過collate_fn后變成torch.Tensor except Exception: # 異常處理 data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) else: # 通過data_queue傳回處理好的batch數據 data_queue.put((idx, samples)) # 顯示刪除中間變量,降低內存消耗 del samples
這里就是不停地輪詢,從index_queues隊列里獲得索引,然后通過collate_fn函數和索引獲取tensor,然后塞入data_queue。
b. 再看put_indices
def _put_indices(self): assert self.batches_outstanding < 2 * self.num_workers # 默認設定是只允許分配2*num_workers個任務,保證內存等資源不被耗盡 indices = next(self.sample_iter, None) # 從sample_iter中拿到dataset中下一輪次的索引,用於fetch數據 if indices is None: return self.index_queues[self.worker_queue_idx].put((self.send_idx, indices)) # 輪詢選擇worker,找到其對應的隊列,向其中發送工作內容(數據編號,數據索引) self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers # worker_queue_idx自增 self.batches_outstanding += 1 # 任務分配數+1 self.send_idx += 1 # 已發送任務總數+1(下批數據編號)
這個就是把索引塞進隊列index_queues
以上就是init,當for循環時,會調用next:
c. __next__返回一個batch
def __next__(self): if self.num_workers == 0: # same-process loading (主進程阻塞式讀取數據) indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch # check if the next sample has already been generated # 先查看數據是否在緩存dict中 if self.rcvd_idx in self.reorder_dict: batch = self.reorder_dict.pop(self.rcvd_idx) return self._process_next_batch(batch) # 異常處理 if self.batches_outstanding == 0: self._shutdown_workers() raise StopIteration while True: assert (not self.shutdown and self.batches_outstanding > 0) # 阻塞式的從data_queue里面獲取處理好的批數據 idx, batch = self._get_batch() # 任務數減一 self.batches_outstanding -= 1 # 這一步可能會造成的周期阻塞現象 # 每次獲取data以后,要校驗和rcvd_idx是否一致 # 若不一致,則先把獲取到的數據放到reorder_dict這個緩存dict中,繼續死循環 # 直到獲取到相應的idx編號於rcvd_idx可以對應上,並將數據返回 if idx != self.rcvd_idx: # store out-of-order samples self.reorder_dict[idx] = batch continue return self._process_next_batch(batch)
__next__里的while True,要從data_queue里面讀到的數據idx和rcvd_idx一致才將數據返回。因此可能會存在如下這種情況:
假設num_workers=8,現在發送了8個數據給相應的worker,此時send_idx=8,rcvd_idx=0。過了一段時間以后,{1,2,3,5,6,7}進程數據准備完畢,此時主進程從data_queue讀取到相關的數據,但由於和rcvd_idx不匹配,只能將其放在緩存里。直到send_idx=0數據准備齊以后,才能將數據返回出去,隨后從緩存中彈出2,3的數據,之后又阻塞等待idx=4的數據。即輸出的數據必須保持順序性!因此在worker變多,出現這種逆序現象可能性會更大,這種現象也會出現在非num_workrers次迭代,只要相應的rcvd_idx沒有得到相關數據,則主進程就會一直等待。
d. process_next_batch
def _process_next_batch(self, batch): # 序號對上以后,rcvd_idx自加1 self.rcvd_idx += 1 # 添加一個fetchdata任務給worker self._put_indices() if isinstance(batch, ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch
這個函數注意的是,只有在__next__中,idx == self.rcvd_idx時才會調用,也就是可能出現多個worker已經准備好了,但是只能放在緩存區,並且無法向index_queues塞入索引,使worker無法保持活躍狀態。
最后對於for循環從dataloader獲取data總體流程:
for epoch in range(num_epoches): for data in dataloader:
對於這個for,其實就是調用了dataloader 的__iter__() 方法, 產生了一個DataLoaderIter,如果是num_worker>0,init里就會創建多線程,並且有兩個隊列,一個是存放dataset的索引index_queues,一個是從index_queues里拿到索引,調用dataset的__getitem__()方法 (如果num_worker>0就多線程調用), 然后用collate_fn來把它們打包成batch,放到data_queue隊列里,反復調用DataLoaderIter 的__next__,從data_queue中獲取batch。
參考:
Pytorch數據讀取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236
PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188
PyTorch36.DataLoader源代碼剖析 https://zhuanlan.zhihu.com/p/169497395
PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705
一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關系 https://zhuanlan.zhihu.com/p/76893455