Pytorch數據讀取詳解


原文:http://studyai.com/article/11efc2bf#采樣器 Sampler & BatchSampler

數據庫DataBase + 數據集DataSet + 采樣器Sampler = 加載器Loader

from torch.utils.data import *

IMDB + Dataset + Sampler || BatchSampler = DataLoader

數據庫 DataBase

Image DataBase 簡稱IMDB,指的是存儲在文件中的數據信息。

文件格式可以多種多樣。比如xml, yaml, json, sql.

VOC是xml格式的,COCO是JSON格式的。

構造IMDB的過程,就是解析這些文件,並建立數據索引的過程。

一般會被解析為Python列表, 以方便后續迭代讀取。

數據集 DataSet

數據集 DataSet: 在數據庫IMDB的基礎上,提供對數據的單例或切片訪問方法。

換言之,就是定義數據庫中對象的索引機制,如何實現單例索引或切片索引。

簡言之,DataSet,通過__getitem__定義了數據集DataSet是一個可索引對象,An Indexerable Object。

即傳入一個給定的索引Index之后,如何按此索引進行單例或切片訪問,單例還是切片視Index是單值還是列表。

Pytorch源碼如下:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    # 定義單例/切片訪問方法,即 dataItem = Dataset[index]
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

自定義數據集要基於上述Dataset基類、IMDB基類,有兩種方法。

# 方法一: 單繼承
class XxDataset(Dataset)
    # 將IMDB作為參數傳入,進行二次封裝
    imdb = IMDB()
    pass
# 方法二: 雙繼承
class XxDataset(IMDB, Dataset):
    pass

采樣器 Sampler & BatchSampler

在實際應用中,數據並不一定是循規蹈矩的序慣訪問,而需要隨機打亂順序來訪問,或需要隨機加權訪問,

因此,按某種特定的規則來讀取數據,就是采樣操作,需要定義采樣器:Sampler

另外,數據也可能並不是一個一個讀取的,而需要一批一批的讀取,即需要批量采樣操作,定義批量采樣器:BatchSampler

所以,只有Dataset的單例訪問方法還不夠,還需要在此基礎上,進一步的定義批量訪問方法。

簡言之,采樣器定義了索引(index)的產生規則,按指定規則去產生索引,從而控制數據的讀取機制

BatchSampler 是基於 Sampler 來構造的: BatchSampler = Sampler + BatchSize

Pytorch源碼如下,

class Sampler(object):
    """Base class for all Samplers.
    采樣器基類,可以基於此自定義采樣器。
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
# 序慣采樣
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)
# 隨機采樣
class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())
    def __len__(self):
        return len(self.data_source)
# 隨機子采樣
class SubsetRandomSampler(Sampler):
    pass
# 加權隨機采樣
class WeightedRandomSampler(Sampler):
    pass
class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler  # ******
        self.batch_size = batch_size
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

由上可見,Sampler本質就是個具有特定規則的可迭代對象,但只能單例迭代。

[x for x in range(10)], range(10)就是個最基本的Sampler,每次循環只能取出其中的一個值.

[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]

BatchSampler對Sampler進行二次封裝,引入了batchSize參數,實現了批量迭代。

from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

加載器 DataLoader

在實際計算中,如果數據量很大,考慮到內存有限,且IO速度很慢,

因此不能一次性的將其全部加載到內存中,也不能只用一個線程去加載。

因而需要多線程、迭代加載, 因而專門定義加載器:DataLoader

DataLoader 是一個可迭代對象, An Iterable Object, 內部配置了魔法函數——iter——,調用它將返回一個迭代器。

該函數可用內置函數iter直接調用,即 DataIteror = iter(DataLoader)

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)

__init__參數包含兩部分,前半部分用於指定數據集 + 采樣器,后半部分為多線程參數

class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        if timeout < 0:
            raise ValueError('timeout option should be non-negative')
        # 檢測是否存在參數沖突: 默認batchSampler vs 自定義BatchSampler
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')
        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')
        # 在此處會強行指定一個 BatchSampler
        if batch_sampler is None:
            # 在此處會強行指定一個 Sampler
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        # 使用自定義的采樣器和批采樣器
        self.sampler = sampler
        self.batch_sampler = batch_sampler
    def __iter__(self):
        # 調用Pytorch的多線程迭代器加載數據
        return DataLoaderIter(self)
    def __len__(self):
        return len(self.batch_sampler)

數據迭代器 DataLoaderIter

迭代器與可迭代對象之間是有區別的。

可迭代對象,意思是對其使用Iter函數時,它可以返回一個迭代器,從而可以連續的迭代訪問它。

迭代器對象,內部有額外的魔法函數__next__,用內置函數next作用其上,則可以連續產生下一個數據,產生規則即是由此函數來確定的。

可迭代對象描述了對象具有可迭代性,但具體的迭代規則由迭代器來描述,這樣解耦的好處是可以對同一個可迭代對象配置多種不同規則的迭代器。

數據集/容器遍歷的一般化流程:NILIS

NILIS規則: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))

  1. sampler 定義索引index的生成規則,返回一個index列表,控制后續的索引訪問過程。
  2. indexer 基於__item__在容器上定義按索引訪問的規則,讓容器成為可索引對象,可用[]操作。
  3. loader 基於__iter__在容器上定義可迭代性,描述加載規則,包括返回一個迭代器,讓容器成為可迭代對象, 可用iter()操作。
  4. next 基於__next__在容器上定義迭代器,描述具體的迭代規則,讓容器成為迭代器對象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler)            # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable()        # __iter__()
dataIterator = DataLoaderIter(dataLoader)     #__next__()
data_iter = iter(dataLoader)
## 遍歷方法1
for _ in range(len(data_iter))
    data = next(data_iter)
## 遍歷方法2
for i, data in enumerate(dataLoader):
    data = data


微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com





2019-8-4




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM