PyTorch is a very widely used deep learning framework, and the first step of model training is building the dataset.
Creating a trainable dataset in PyTorch takes three steps:
1. Create a Dataset object
2. Create a DataLoader object
3. Loop over the DataLoader and feed each batch of data and labels into the model for training
The Dataset and DataLoader objects are built from the Dataset and DataLoader classes in torch.utils.data; a minimal end-to-end sketch of the three steps follows.
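Before going into the details, here is a minimal sketch of the three steps above. It is not the dataset built later in this post; it uses a toy TensorDataset and a toy linear model purely for illustration:

import torch
from torch.utils.data import TensorDataset, DataLoader

# Step 1: create a Dataset object (here a ready-made TensorDataset with random toy data)
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

# Step 2: create a DataLoader object
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Step 3: loop over the DataLoader, feeding data and label into the model
model = torch.nn.Linear(10, 2)
for data, label in loader:
    output = model(data)          # forward pass on one batch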
First, let's look at the source code of torch.utils.data.Dataset:
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
torch.utils.data.Dataset is the abstract class that represents a custom dataset. To define our own dataset class we subclass it; all we have to do is override the __len__ and __getitem__ methods.
What they do:
__len__(self): returns the size of the dataset
__getitem__(self, index): returns the image and label for the given index (see the minimal sketch below)
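As a quick illustration of this two-method contract, here is a minimal in-memory dataset (a sketch, not part of the original example; ToyDataset and its fields are made-up names):

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    # Hold samples and labels in memory and index into them
    def __init__(self, samples, labels):
        assert len(samples) == len(labels)
        self.samples = samples
        self.labels = labels

    def __len__(self):
        # size of the dataset
        return len(self.samples)

    def __getitem__(self, index):
        # one (sample, label) pair for the given index
        return self.samples[index], self.labels[index]

toy = ToyDataset(torch.randn(8, 3), torch.arange(8))
print(len(toy))   # 8
print(toy[0])     # (tensor of shape [3], tensor(0))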
We usually define our own dataset class as follows:
# Import the required packages
import os
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


# Custom dataset: subclass the Dataset parent class
class VeriDataset(data.Dataset):
    # Initialization: (image directory, txt file, transform, ...)
    def __init__(self, data_dir, train_list, train_data_transform=None, is_train=True):
        '''
        data_dir: root directory of the images
        train_list: txt file listing the image names
        train_data_transform: image preprocessing transform
        is_train: flag for training set vs. validation set
        '''
        super(VeriDataset, self).__init__()
        self.is_train = is_train
        self.data_dir = data_dir
        self.train_data_transform = train_data_transform
        # Read the .txt file
        f = open(train_list, 'r')
        lines = f.readlines()
        f.close()
        self.names = []
        self.labels = []
        self.cams = []
        if is_train == True:
            i = 0
            for line in lines:
                if i % 10000 == 0:
                    print(line)
                line = line.strip().split(' ')
                self.names.append(line[0])
                self.labels.append(line[1])
                self.cams.append(line[0].split('_')[1])
                i += 1
        # The training set and validation set store their file lists differently
        else:
            for line in lines:
                line = line.strip()
                self.names.append(line)
                self.labels.append(line.split('_')[0])
                self.cams.append(line.split('_')[1])
        # self.labels = np.array(self.labels, dtype=np.float32)

    # Return the image and label for a given index (overrides the parent method)
    def __getitem__(self, index):
        '''
        index is supplied by the DataLoader's sampler
        '''
        img = Image.open(os.path.join(self.data_dir, self.names[index])).convert('RGB')
        # print("image loaded")
        target = int(self.labels[index])
        camid = self.cams[index]
        if self.train_data_transform != None:
            img = self.train_data_transform(img)
        return img, target, camid

    # Return the dataset size (overrides the parent method)
    def __len__(self):
        return len(self.names)
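Instantiating this class might look as follows. This is a hypothetical usage sketch: the directory './VeRi/image_train', the list file 'name_train.txt', and the transform pipeline are placeholders, not values from the original post:

from torchvision import transforms

# Placeholder preprocessing pipeline
train_data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Placeholder paths: point these at your own image directory and list file
train_set = VeriDataset(data_dir='./VeRi/image_train',
                        train_list='name_train.txt',
                        train_data_transform=train_data_transform,
                        is_train=True)

img, target, camid = train_set[0]   # calls __getitem__ with index 0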
Create a DataLoader object
DataLoader is another important PyTorch interface. It wraps a custom Dataset and, according to batch_size, whether to shuffle, and so on, packs the samples into batch-sized Tensors used for the subsequent training.
Let's look at the DataLoader source code:
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may use
              ``torch.initial_seed()`` to access the PyTorch seed for each worker
              in :attr:`worker_init_fn`, and use it to set other seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be
                 an unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the index list
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
The parameters we usually care about are the following (the sketch after the list shows them in use):
dataset: the dataset to load samples from
batch_size: how many samples per batch
shuffle: whether to reshuffle the data at the start of every epoch
num_workers: how many subprocesses to use for data loading
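A small sketch of how these parameters come together (reusing the ToyDataset defined above; the shapes in the comments are for that toy data):

from torch.utils.data import DataLoader

toy_loader = DataLoader(dataset=toy,     # the dataset to draw samples from
                        batch_size=4,    # 4 samples per batch
                        shuffle=True,    # reshuffle the indices every epoch
                        num_workers=0)   # load in the main process

for samples, labels in toy_loader:
    # the default collate_fn stacks individual samples into batched tensors
    print(samples.shape, labels.shape)   # torch.Size([4, 3]) torch.Size([4])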
After constructing the DataLoader, we usually iterate over it with a for loop to feed the data into training.
Because DataLoader implements __iter__() but not __next__(), a DataLoader is an iterable rather than an iterator; its __iter__() must therefore return an iterator, _DataLoaderIter.
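You can see this distinction from the user side (a small sketch, reusing toy_loader from above):

it = iter(toy_loader)   # DataLoader.__iter__ returns a separate iterator object
batch = next(it)        # the iterator, not the DataLoader itself, implements __next__
# next(toy_loader)      # would fail: a DataLoader is iterable but not an iterator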
_DataLoaderIter is where the __next__() method is implemented:
class _DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue,
                          self.collate_fn, base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    def __iter__(self):
        return self
A usage example:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)

for i, (image, target, camid) in enumerate(train_loader):
    batch_size = image.size(0)
    target = target.cuda()            # move the batched targets to the GPU
    # volatile=True no longer has any effect, so use torch.no_grad() instead:
    # image = torch.autograd.Variable(image, volatile=True).cuda()
    with torch.no_grad():
        image = torch.autograd.Variable(image).cuda()
    output, feat = model(image)