pytorch的dataset與dataloader解析


整理一下pytorch獲取的流程:

  1. 創建Dataset對象
  2. 創建DataLoader對象,裝載有dataset對象
  3. 循環DataLoader對象,DataLoader.__iter__返回的是DataLoaderIter對象
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for data in dataloader:
        ....

根據源碼分析:torch.utils.data

 

1 - Dataset:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

Dataset這是一個抽象類,不能實例化,需要重寫類方法,關鍵點有兩個:

  • __getitem__ 這個很重要,規定了如何讀數據,比如常用的transform
  • __len__ 這個就是返回數據集的長度,比如:return len(self.data)

2 - DataLoader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

先看一下主要參數:

  • dataset:就是 torch.utils.data.Dataset 類的實例。也就是說為了使用 DataLoader 類,需要先定義一個 torch.utils.data.Dataset 類的實例。
  • batch_size:每一個批次需要加載的訓練樣本個數。
  • shuffle:如果設置為 True 表示訓練樣本數據會被隨機打亂,默認值為 False。一般會設置為 True 。
  • sampler:自定義從數據集中取樣本的策略,如果指定這個參數,那么 shuffle 必須為 False 。從源碼中可以看到,如果指定了該參數,同時 shuffle 設定為 True,DataLoader 的 __init__ 函數就會拋出一個異常 。
  • batch_sampler:與 sampler 類似,但是一次只返回一個 batch 的 indices(索引),需要注意的是,一旦指定了這個參數,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源碼中同樣做了限制。
  • num_workers:表示會使用多少個線程來加載訓練數據;默認值為 0,表示數據加載直接在主線程中進行。
  • collate_fn:對每一個 batch 的數據做一些你想要的操作。一個例子,https://zhuanlan.zhihu.com/p/346332974
  • pin_memory:把數據轉移到和 GPU 相關聯的 CPU 內存,加速 GPU 載入數據的速度。
  • drop_last:比如你的batch_size設置為 32,而一個 epoch 只有 100 個樣本;如果設置為 True,那么訓練的時候后面的 4 個就被扔掉了。如果為 False(默認),那么會繼續正常執行,只是最后的 batch_size 會小一點。
  • timeout:加載一個 batch 數據的超時時間。
  • worker_init_fn:指定每個數據加載線程的入口函數。

源碼分析:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。
                    sampler = RandomSampler(dataset)
                else:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。
                    sampler = SequentialSampler(dataset)
            # Sampler 是個迭代器,一次之只返回一個 索引
            # BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler) 

可以發現__iter__返回的是DataLoaderIter

 

3 - DataLoaderIter

先看init初始化:

if self.num_workers > 0:
    self.worker_init_fn = loader.worker_init_fn
# 定義了workers相同數量個Queue並放置在index_queues這個list中, # 這些Queue與worker一一對應,用來給worker傳遞“工作內容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用於下一個工作的workre序號,主進程輪詢使用不同workers self.worker_queue_idx = 0
# 各個workre將自己所取得的數據傳遞給wokrker_result_queue,供主進程fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 記錄當前時刻分配了多少個任務(可能有處於等待狀態的任務) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 發送出去數據的編號 self.send_idx = 0 # 接受到數據的編號 self.rcvd_idx = 0 # 緩存區 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相應的進程,目標函數為_worker_loop # 參數:dataset(用於數據讀取),index_queues[i]為worker對應的index_queue # 以及用於輸出的queue # 此處主要用於數據讀取后的pin_memory操作,不影響多進程主邏輯,暫不展開 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 將父進程設置為守護進程,保證父進程結束后,worker進程也結束,必須設置在start之前 w.start() # 下面是一些系統信號處理邏輯,對這方面我還不太熟悉就不介紹了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化后生成2*num_workers數量個prefetch的數據,使dataloader提前工作,提升整體效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()

init過程有兩個函數,一個是worker_loop,另個是put_indices

a. 先看worker_loop:

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)
    
    # 父進程狀態監測
    watchdog = ManagerWatchdog()
    
    # 死循環查詢是否有任務傳進來
    while True:
        try:
            # 從index_queue獲取相應數據
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            # 獲得以后for循環進行讀取數據讀取,此處和單進程的工作原理是一樣的
            # 因此時間花費和batchsize數量呈線性關系
            samples = collate_fn([dataset[i] for i in batch_indices])
            # 經過collate_fn后變成torch.Tensor
        except Exception:
            # 異常處理
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            # 通過data_queue傳回處理好的batch數據
            data_queue.put((idx, samples))
            # 顯示刪除中間變量,降低內存消耗
            del samples

這里就是不停地輪詢,從index_queues隊列里獲得索引,然后通過collate_fn函數和索引獲取tensor,然后塞入data_queue

 

b. 再看put_indices

def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    # 默認設定是只允許分配2*num_workers個任務,保證內存等資源不被耗盡
    indices = next(self.sample_iter, None)
    # 從sample_iter中拿到dataset中下一輪次的索引,用於fetch數據
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    # 輪詢選擇worker,找到其對應的隊列,向其中發送工作內容(數據編號,數據索引)
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    # worker_queue_idx自增
    self.batches_outstanding += 1
    # 任務分配數+1
    self.send_idx += 1
    # 已發送任務總數+1(下批數據編號) 

這個就是把索引塞進隊列index_queues

以上就是init,當for循環時,會調用next:

 

c. __next__返回一個batch

def __next__(self):
    if self.num_workers == 0:  # same-process loading  (主進程阻塞式讀取數據)
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch
    
    # check if the next sample has already been generated
    # 先查看數據是否在緩存dict中
    if self.rcvd_idx in self.reorder_dict:
        batch = self.reorder_dict.pop(self.rcvd_idx)
        return self._process_next_batch(batch)
    # 異常處理
    if self.batches_outstanding == 0:
        self._shutdown_workers()
        raise StopIteration
    while True:
        assert (not self.shutdown and self.batches_outstanding > 0)
        # 阻塞式的從data_queue里面獲取處理好的批數據
        idx, batch = self._get_batch() 
        # 任務數減一
        self.batches_outstanding -= 1
        # 這一步可能會造成的周期阻塞現象
        # 每次獲取data以后,要校驗和rcvd_idx是否一致
        # 若不一致,則先把獲取到的數據放到reorder_dict這個緩存dict中,繼續死循環
        # 直到獲取到相應的idx編號於rcvd_idx可以對應上,並將數據返回
        if idx != self.rcvd_idx:
            # store out-of-order samples
            self.reorder_dict[idx] = batch
            continue
        return self._process_next_batch(batch)

__next__里的while True,要從data_queue里面讀到的數據idx和rcvd_idx一致才將數據返回。因此可能會存在如下這種情況:

假設num_workers=8,現在發送了8個數據給相應的worker,此時send_idx=8,rcvd_idx=0。過了一段時間以后,{1,2,3,5,6,7}進程數據准備完畢,此時主進程從data_queue讀取到相關的數據,但由於和rcvd_idx不匹配,只能將其放在緩存里。直到send_idx=0數據准備齊以后,才能將數據返回出去,隨后從緩存中彈出2,3的數據,之后又阻塞等待idx=4的數據。即輸出的數據必須保持順序性!因此在worker變多,出現這種逆序現象可能性會更大,這種現象也會出現在非num_workrers次迭代,只要相應的rcvd_idx沒有得到相關數據,則主進程就會一直等待。

 

d. process_next_batch

def _process_next_batch(self, batch):
    # 序號對上以后,rcvd_idx自加1
    self.rcvd_idx += 1
    # 添加一個fetchdata任務給worker
    self._put_indices()
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    return batch

  

這個函數注意的是,只有在__next__中,idx == self.rcvd_idx時才會調用,也就是可能出現多個worker已經准備好了,但是只能放在緩存區,並且無法向index_queues塞入索引,使worker無法保持活躍狀態。

 

最后對於for循環從dataloader獲取data總體流程:

for epoch in range(num_epoches):
    for data in dataloader:

 

 

對於這個for,其實就是調用了dataloader 的__iter__() 方法, 產生了一個DataLoaderIter,如果是num_worker>0,init里就會創建多線程,並且有兩個隊列,一個是存放dataset的索引index_queues,一個是從index_queues里拿到索引,調用dataset的__getitem__()方法 (如果num_worker>0就多線程調用), 然后用collate_fn來把它們打包成batch,放到data_queue隊列里,反復調用DataLoaderIter 的__next__,從data_queue中獲取batch。

 

參考:

Pytorch數據讀取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236 

PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188

PyTorch36.DataLoader源代碼剖析 https://zhuanlan.zhihu.com/p/169497395

PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關系 https://zhuanlan.zhihu.com/p/76893455


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM