PyTorch源碼解讀之torch.utils.data.DataLoader(轉)


原文鏈接

https://blog.csdn.net/u014380165/article/details/79058479

寫得特別好!最近正好在學習pytorch,學習一下!

PyTorch中數據讀取的一個重要接口是torch.utils.data.DataLoader,該接口定義在dataloader.py腳本中,只要是用PyTorch來訓練模型基本都會用到該接口,該接口主要用來將自定義的數據讀取接口的輸出或者PyTorch已有的數據讀取接口的輸入按照batch size封裝成Tensor,后續只需要再包裝成Variable即可作為模型的輸入,因此該接口有點承上啟下的作用,比較重要。這篇博客介紹該接口的源碼,主要包含DataLoader和DataLoaderIter兩個類
dataloader.py腳本的的github地址https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

DataLoader類源碼如下。先看看__init__中的幾個重要的輸入:1、dataset,這個就是PyTorch已有的數據讀取接口(比如torchvision.datasets.ImageFolder)或者自定義的數據接口的輸出,該輸出要么是torch.utils.data.Dataset類的對象,要么是繼承自torch.utils.data.Dataset類的自定義類的對象。2、batch_size,根據具體情況設置即可。3、shuffle,一般在訓練數據中會采用。4、collate_fn,是用來處理不同情況下的輸入dataset的封裝,一般采用默認即可,除非你自定義的數據讀取輸出非常少見。5、batch_sampler,從注釋可以看出,其和batch_size、shuffle等參數是互斥的,一般采用默認。6、sampler,從代碼可以看出,其和shuffle是互斥的,一般默認即可。7、num_workers,從注釋可以看出這個參數必須大於等於0,0的話表示數據導入在主進程中進行,其他大於0的數表示通過多個進程來導入數據,可以加快數據導入速度。8、pin_memory,注釋寫得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一個數據拷貝的問題。9、timeout,是用來設置數據讀取的超時時間的,但超過這個時間還沒讀取到數據的話就會報錯。
__init__中,RandomSampler類表示隨機采樣且不重復,所以起到的就是shuffle的作用。BatchSampler類則是把batch size個RandomSampler類對象封裝成一個,這樣就實現了隨機選取一個batch的目的。這兩個采樣類都是定義在sampler.py腳本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上這些都是初始化的時候進行的。當代碼運行到要從torch.utils.data.DataLoader類生成的對象中取數據的時候,比如:
train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
...
就會調用DataLoader類的__iter__方法,__iter__方法就一行代碼:return DataLoaderIter(self),輸入正是DataLoader類的屬性。因此當調用__iter__方法的時候就牽扯到另外一個類:DataLoaderIter,接下來介紹。

class DataLoader(object):
""" Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset. Arguments: dataset (Dataset): dataset from which to load the data. batch_size (int, optional): how many samples per batch to load (default: 1). shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: False). sampler (Sampler, optional): defines the strategy to draw samples from the dataset. If specified, ``shuffle`` must be False. batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last. num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) collate_fn (callable, optional): merges a list of samples to form a mini-batch. pin_memory (bool, optional): If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them. drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False) timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0) worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id as input, after seeding and before data loading. (default: None) """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

DataLoaderIter類源碼如下。self.index_queue = multiprocessing.SimpleQueue()中的multiprocessing是Python中的多進程管理包,而threading則是Python中的多線程管理包,二者很大一部分的接口用法類似。還是照例先看看__init__,前面部分都是一些賦值操作,比較特殊的是self.sample_iter = iter(self.batch_sampler),得到的self.sample_iter可以通過next(self.sample_iter)來獲取batch size個數據的index。self.rcvd_idx表示讀取到的一個batch數據的index,初始化為0,該值在迭代讀取數據的時候會用到。if self.num_workers語句是針對多進程或單進程的情況進行初始化,如果不是設置為多進程讀取數據,那么就不需要這些初始化操作,后面會介紹單進程數據讀取。在if語句中通過multiprocessing.SimpleQueue()類創建了一個簡單的隊列對象。multiprocessing.Process類就是構造進程的類,這里根據設定的進程數來啟動,然后賦值給self.workers。接下來的一個for循環就通過調用start方法依次啟動self.workers中的進程。接下來關於self.pin_memory的判斷語句,該判斷語句內部主要是實現了多線程操作。self.pin_memory的含義在前面已經介紹過了,當為True的時候,就會把數據拷到CUDA中。self.data_queue = queue.Queue()是通過Python的queue模塊初始化得到一個先進先出的隊列(queue模塊也可以初始化得到先進后出的隊列,需要用queue.LifoQueue()初始化),queue模塊主要應用在多線程讀取數據中。在threading.Thread的args參數中,第一個參數in_data就是一個進程的數據,一個進程中不同線程的數據也是通過隊列來維護的,這里采用的是Python的queue模塊來初始化得到一個隊列:queue.Queue()。初始化結束后,就會調用__next__方法,接下來介紹。
總的來說,如果設置為多進程讀取數據,那么就會采用隊列的方式來讀,如果不是采用多進程來讀取數據,那就采用普通方式來讀

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

DataLoaderIter類的__next__方法如下,包含3個if語句和1個while語句。
第一個if語句是用來處理self.num_workers等於0的情況,也就是不采用多進程進行數據讀取,可以看出在這個if語句中先通過indices = next(self.sample_iter)獲取長度為batch size的列表:indices,這個列表的每個值表示一個batch中每個數據的index,每執行一次next操作都會讀取一批長度為batch size的indices列表。然后通過self.collate_fn函數將batch size個tuple(每個tuple長度為2,其中第一個值是數據,Tensor類型,第二個值是標簽,int類型)封裝成一個list,這個list長度為2,兩個值都是Tensor,一個是batch size個數據組成的FloatTensor,另一個是batch size個標簽組成的LongTensor。所以簡單講self.collate_fn函數就是將batch size個分散的Tensor封裝成一個Tensor。batch = pin_memory_batch(batch)中pin_memory_batch函數的作用就是將輸入batch的每個Tensor都拷貝到CUDA中,該函數后面會詳細介紹。
第二個if語句是判斷當前想要讀取的batch的index(self.rcvd_idx)是否之前已經讀出來過(已讀出來的index和batch數據保存在self.reorder_dict字典中,可以結合最后的while語句一起看,因為self.reorder_dict字典的更新是在最后的while語句中),如果之前已經讀取過了,就根據這個index從reorder_dict字典中彈出對應的數據。最后返回batch數據的時候是 return self._process_next_batch(batch),該方法后面會詳細介紹。主要做是獲取下一個batch的數據index信息。
第三個if語句,self.batches_outstanding的值在前面初始中調用self._put_indices()方法時修改了,所以假設你的進程數self.num_workers設置為3,那么這里self.batches_outstanding就是3*2=6,可具體看self._put_indices()方法。
最后的while循環就是真正用來從隊列中讀取數據的操作,最主要的就是idx, batch = self._get_batch(),通過調用_get_batch()方法來讀取,后面有介紹,簡單講就是調用了隊列的get方法得到下一個batch的數據,得到的batch一般是長度為2的列表,列表的兩個值都是Tensor,分別表示數據(是一個batch的)和標簽。_get_batch()方法除了返回batch數據外,還得到另一個輸出:idx,這個輸出表示batch的index,這個if idx != self.rcvd_idx條件語句表示如果你讀取到的batch的index不等於當前想要的index:selg,rcvd_idx,那么就將讀取到的數據保存在字典self.reorder_dict中:self.reorder_dict[idx] = batch,然后繼續讀取數據,直到讀取到的數據的index等於self.rcvd_idx。

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

pin_memory_batch函數不是定義在DataLoader類或DataLoaderIter類中。該函數主要是對batch中的Tensor執行batch.pin_memory()操作,這里的很多條件語句只是用來判斷batch的類型,假如batch是一個列表,列表中的每個值是Tensor,那么就會執行 elif isinstance(batch, collections.Sequence):這個條件,從而遍歷該列表中的每個Tensor,然后執行第一個條件語句的內容: return batch.pin_memory()

def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
iter() 與 next()的調用方法
dataiter = iter(trainloader) # 每次迭代取的是一個batch images, labels = dataiter.next() # 如果batch_size為4,則取出來的images是4×c×h×w的tensor,labels是1×4的向量

DataloaderIter類的_get_batch方法。主要根據是否設置了超時時間來操作,如果超過指定的超時時間后沒有從隊列中讀到數據就報錯,如果不設置超時時間且一致沒有從隊列中讀到數據,那么就會一直卡着且不報錯,這部分是PyTorch后來修的一個bug。

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(True, self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

DataLoaderIter類的_process_next_batch方法。首先對self.rcvd_idx進行加一,也就是更新下下一個要讀取的batch數據的index。然后調用_put_indices()方法獲取下一個batch的每個數據的index。

   def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

DataLoaderIter類的_put_indices方法。該方法主要實現從self.sample_iter中讀取下一個batch數據中每個數據的index:indices = next(self.sample_iter, None),注意這里的index和前面idx是不一樣的,這里的index是一個batch中每個數據的index,idx是一個batch的index;然后將讀取到的index通過調用queue對象的put方法壓到隊列self.index_queue中:self.index_queue.put((self.send_idx, indices))

def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM