整理一下pytorch獲取的流程:
- 創建Dataset對象
- 創建DataLoader對象,裝載有dataset對象
- 循環DataLoader對象,DataLoader.__iter__返回的是DataLoaderIter對象
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....
根據源碼分析:torch.utils.data
1 - Dataset:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
Dataset這是一個抽象類,不能實例化,需要重寫類方法,關鍵點有兩個:
-
__getitem__ 這個很重要,規定了如何讀數據,比如常用的transform
-
__len__ 這個就是返回數據集的長度,比如:return len(self.data)
2 - DataLoader:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
先看一下主要參數:
- dataset:就是 torch.utils.data.Dataset 類的實例。也就是說為了使用 DataLoader 類,需要先定義一個 torch.utils.data.Dataset 類的實例。
- batch_size:每一個批次需要加載的訓練樣本個數。
- shuffle:如果設置為 True 表示訓練樣本數據會被隨機打亂,默認值為 False。一般會設置為 True 。
- sampler:自定義從數據集中取樣本的策略,如果指定這個參數,那么 shuffle 必須為 False 。從源碼中可以看到,如果指定了該參數,同時 shuffle 設定為 True,DataLoader 的 __init__ 函數就會拋出一個異常 。
- batch_sampler:與 sampler 類似,但是一次只返回一個 batch 的 indices(索引),需要注意的是,一旦指定了這個參數,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源碼中同樣做了限制。
- num_workers:表示會使用多少個線程來加載訓練數據;默認值為 0,表示數據加載直接在主線程中進行。
- collate_fn:對每一個 batch 的數據做一些你想要的操作。一個例子,https://zhuanlan.zhihu.com/p/346332974
- pin_memory:把數據轉移到和 GPU 相關聯的 CPU 內存,加速 GPU 載入數據的速度。
- drop_last:比如你的batch_size設置為 32,而一個 epoch 只有 100 個樣本;如果設置為 True,那么訓練的時候后面的 4 個就被扔掉了。如果為 False(默認),那么會繼續正常執行,只是最后的 batch_size 會小一點。
- timeout:加載一個 batch 數據的超時時間。
- worker_init_fn:指定每個數據加載線程的入口函數。
源碼分析:
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。
sampler = RandomSampler(dataset)
else:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。
sampler = SequentialSampler(dataset)
# Sampler 是個迭代器,一次之只返回一個 索引
# BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
可以發現__iter__返回的是DataLoaderIter
3 - DataLoaderIter
先看init初始化:
if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
# 定義了workers相同數量個Queue並放置在index_queues這個list中,
# 這些Queue與worker一一對應,用來給worker傳遞“工作內容”
self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用於下一個工作的workre序號,主進程輪詢使用不同workers
self.worker_queue_idx = 0
# 各個workre將自己所取得的數據傳遞給wokrker_result_queue,供主進程fetch
self.worker_result_queue = multiprocessing.SimpleQueue()
# 記錄當前時刻分配了多少個任務(可能有處於等待狀態的任務)
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
# 發送出去數據的編號
self.send_idx = 0
# 接受到數據的編號
self.rcvd_idx = 0
# 緩存區
self.reorder_dict = {}
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queues[i],
self.worker_result_queue, self.collate_fn, base_seed + i,
self.worker_init_fn, i))
for i in range(self.num_workers)]
# 初始化相應的進程,目標函數為_worker_loop
# 參數:dataset(用於數據讀取),index_queues[i]為worker對應的index_queue
# 以及用於輸出的queue
# 此處主要用於數據讀取后的pin_memory操作,不影響多進程主邏輯,暫不展開
if self.pin_memory or self.timeout > 0:
...
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
# 將父進程設置為守護進程,保證父進程結束后,worker進程也結束,必須設置在start之前
w.start()
# 下面是一些系統信號處理邏輯,對這方面我還不太熟悉就不介紹了。
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# 初始化后生成2*num_workers數量個prefetch的數據,使dataloader提前工作,提升整體效率。
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
init過程有兩個函數,一個是worker_loop,另個是put_indices
a. 先看worker_loop:
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
global _use_shared_memory
_use_shared_memory = True
# Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
if init_fn is not None:
init_fn(worker_id)
# 父進程狀態監測
watchdog = ManagerWatchdog()
# 死循環查詢是否有任務傳進來
while True:
try:
# 從index_queue獲取相應數據
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive():
continue
else:
break
if r is None:
break
idx, batch_indices = r
try:
# 獲得以后for循環進行讀取數據讀取,此處和單進程的工作原理是一樣的
# 因此時間花費和batchsize數量呈線性關系
samples = collate_fn([dataset[i] for i in batch_indices])
# 經過collate_fn后變成torch.Tensor
except Exception:
# 異常處理
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
# 通過data_queue傳回處理好的batch數據
data_queue.put((idx, samples))
# 顯示刪除中間變量,降低內存消耗
del samples
這里就是不停地輪詢,從index_queues隊列里獲得索引,然后通過collate_fn函數和索引獲取tensor,然后塞入data_queue。
b. 再看put_indices
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
# 默認設定是只允許分配2*num_workers個任務,保證內存等資源不被耗盡
indices = next(self.sample_iter, None)
# 從sample_iter中拿到dataset中下一輪次的索引,用於fetch數據
if indices is None:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
# 輪詢選擇worker,找到其對應的隊列,向其中發送工作內容(數據編號,數據索引)
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
# worker_queue_idx自增
self.batches_outstanding += 1
# 任務分配數+1
self.send_idx += 1
# 已發送任務總數+1(下批數據編號)
這個就是把索引塞進隊列index_queues
以上就是init,當for循環時,會調用next:
c. __next__返回一個batch
def __next__(self):
if self.num_workers == 0: # same-process loading (主進程阻塞式讀取數據)
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
# 先查看數據是否在緩存dict中
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
# 異常處理
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
# 阻塞式的從data_queue里面獲取處理好的批數據
idx, batch = self._get_batch()
# 任務數減一
self.batches_outstanding -= 1
# 這一步可能會造成的周期阻塞現象
# 每次獲取data以后,要校驗和rcvd_idx是否一致
# 若不一致,則先把獲取到的數據放到reorder_dict這個緩存dict中,繼續死循環
# 直到獲取到相應的idx編號於rcvd_idx可以對應上,並將數據返回
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)
__next__里的while True,要從data_queue里面讀到的數據idx和rcvd_idx一致才將數據返回。因此可能會存在如下這種情況:
假設num_workers=8,現在發送了8個數據給相應的worker,此時send_idx=8,rcvd_idx=0。過了一段時間以后,{1,2,3,5,6,7}進程數據准備完畢,此時主進程從data_queue讀取到相關的數據,但由於和rcvd_idx不匹配,只能將其放在緩存里。直到send_idx=0數據准備齊以后,才能將數據返回出去,隨后從緩存中彈出2,3的數據,之后又阻塞等待idx=4的數據。即輸出的數據必須保持順序性!因此在worker變多,出現這種逆序現象可能性會更大,這種現象也會出現在非num_workrers次迭代,只要相應的rcvd_idx沒有得到相關數據,則主進程就會一直等待。
d. process_next_batch
def _process_next_batch(self, batch):
# 序號對上以后,rcvd_idx自加1
self.rcvd_idx += 1
# 添加一個fetchdata任務給worker
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
這個函數注意的是,只有在__next__中,idx == self.rcvd_idx時才會調用,也就是可能出現多個worker已經准備好了,但是只能放在緩存區,並且無法向index_queues塞入索引,使worker無法保持活躍狀態。
最后對於for循環從dataloader獲取data總體流程:
for epoch in range(num_epoches):
for data in dataloader:
對於這個for,其實就是調用了dataloader 的__iter__() 方法, 產生了一個DataLoaderIter,如果是num_worker>0,init里就會創建多線程,並且有兩個隊列,一個是存放dataset的索引index_queues,一個是從index_queues里拿到索引,調用dataset的__getitem__()方法 (如果num_worker>0就多線程調用), 然后用collate_fn來把它們打包成batch,放到data_queue隊列里,反復調用DataLoaderIter 的__next__,從data_queue中獲取batch。
參考:
Pytorch數據讀取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236
PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188
PyTorch36.DataLoader源代碼剖析 https://zhuanlan.zhihu.com/p/169497395
PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705
一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關系 https://zhuanlan.zhihu.com/p/76893455
