構建data_loader原理步驟
# engine/default.py
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
# Assume these objects must be constructed in this order.
data_loader = self.build_train_loader(cfg)
...
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
"""
return build_detection_train_loader(cfg)
函數調用關系如下圖:
結合前面兩篇文章的內容可以看到detectron2在構建model,optimizer和data_loader的時候都是在對應的build.py
文件里實現的。我們看一下build_detection_train_loader
是如何定義的(對應上圖中紫色方框內的部分(自下往上的順序)):
def build_detection_train_loader(cfg, mapper=None):
"""
A data loader is created by the following steps:
1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
2. Start workers to work on the dicts. Each worker will:
* Map each metadata dict into another format to be consumed by the model.
* Batch them by simply putting dicts into a list.
The batched ``list[mapped_dict]`` is what this dataloader will return.
Args:
cfg (CfgNode): the config
mapper (callable): a callable which takes a sample (dict) from dataset and
returns the format to be consumed by the model.
By default it will be `DatasetMapper(cfg, True)`.
Returns:
a torch DataLoader object
"""
# 獲得dataset_dicts
dataset_dicts = get_detection_dataset_dicts(
cfg.DATASETS.TRAIN,
filter_empty=True,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)
# 將dataset_dicts轉化成torch.utils.data.Dataset
dataset = DatasetFromList(dataset_dicts, copy=False)
# 進一步轉化成MapDataset,每次讀取數據時都會調用mapper來對dict進行解析
if mapper is None:
mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper)
# 采樣器
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
if sampler_name == "TrainingSampler":
sampler = samplers.TrainingSampler(len(dataset))
...
batch_sampler = build_batch_data_sampler(
sampler, images_per_worker, group_bin_edges, aspect_ratios
)
# 數據迭代器 data_loader
data_loader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.DATALOADER.NUM_WORKERS,
batch_sampler=batch_sampler,
collate_fn=trivial_batch_collator,
worker_init_fn=worker_init_reset_seed,
)
return data_loader
由上面的源代碼可以看出總共是五個步驟,我們只對前面三個部分進行詳細介紹,后面的采樣器和data_loader可以參閱一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關系。
獲得dataset_dicts
get_detection_dataset_dicts(dataset_names)
函數需要傳遞的一個重要參數是dataset_names
,這個參數其實就是一個字符串,用來指定數據集的名稱。通過這個字符串,該函數會調用data/catalog.py
的DatasetCatalog
類來進行解析得到一個包含數據信息的字典。
解析的原理是:DatasetCatalog
有一個字典_REGISTERED
,默認已經注冊好了例如coco,voc
這些數據集的信息。如果你想要使用你自己的數據集,那么你需要在最開始前你需要定義你的數據集名字以及定義一個函數(這個函數不需要傳參,而且最后會返回一個dict,該dict包含你的數據集信息),舉個栗子:
from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
...
return dict
DatasetCatalog.register(my_dataset_name, get_dicts)
當然,如果你的數據集已經是COCO的格式了,那么你也可以使用如下方法進行注冊:
from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")
另外需要注意的是一個數據集其實是可以由兩個類來定義的,一個是前面介紹了的DatasetCatalog
,另一個是MetadataCatalog
。
MetadataCatalog
的作用是記錄數據集的一些特征,這樣我們就可以很方便的在整個代碼中獲取數據集的特征信息。在注冊DatasetCatalog
后,我們可以按如下栗子對MetadataCatalog
進行注冊並定義我們后面可能會用到的屬性特征:
from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]
# 也可以這樣
MetadataCatalog.get("my_dataset").set("thing_classes",["person", "dog"])
注意:如果你的數據集名字未注冊過,MetadataCatalog.get
會自動進行注冊,然后會自動設置你所設定的屬性值。
其實MetadataCatalog
還有其他的特征屬性可以設置,如stuff_classes
,stuff_colors
等等。你可能會好奇thing_classes
和stuff_classes
有什么區別,區別如下:
- 抽象解釋:
thing_classes
用於指定instance-level任務,stuff_classes
用於semantic segmentation任務。 - 具體解釋:像椅子,書這種可數的東西,就可以理解成
thing
,所以用於instance-level;而雪、天空這種不可數的就理解成stuff
,所以用於semantic segmentation。參考On Seeing Stuff: The Perception of Materials by Humans and Machines
最后,get_detection_dataset_dicts
會返回一個包含若干個dict的list,之所以是list是因為參數dataset_names
也是一個list,這樣我們就可以制定多個names來同時對數據進行讀取。
解析成DatasetFromList
DatasetFromList(dataset_dict)
函數定義在detectron2/data/common.py
中,它其實就是一個torch.utils.data.Dataset
類,其源碼如下
class DatasetFromList(data.Dataset):
"""
Wrap a list to a torch Dataset. It produces elements of the list as data.
"""
def __init__(self, lst: list, copy: bool = True):
"""
Args:
lst (list): a list which contains elements to produce.
copy (bool): whether to deepcopy the element when producing it,
so that the result can be modified in place without affecting the
source in the list.
"""
self._lst = lst
self._copy = copy
def __len__(self):
return len(self._lst)
def __getitem__(self, idx):
if self._copy:
return copy.deepcopy(self._lst[idx])
else:
return self._lst[idx]
這個很簡單就不加贅述了
將DatsetFromList
轉化成MapDataset
其實DatsetFromList
和MapDataset
都是torch.utils.data.Dataset
的子類,那他們的區別是什么呢?很簡單,區別就是后者使用了mapper
。
在解釋mapper
是什么之前我們首先要知道的是,在detectron2中,一張圖片對應的是一個dict,那么整個數據集就是list[dict]。之后我們再看DatsetFromList
,它的__getitem__
函數非常簡單,它只是簡單粗暴地就返回了指定idx的元素。顯然這樣是不行的,因為在把數據扔給模型訓練之前我們肯定還要對數據做一定的處理,而這個工作就是由mapper
來做的,默認情況下使用的是detectron2/data/dataset_mapper.py
中定義的DatasetMapper
,如果你需要自定義一個mapper
也可以參考這個寫。
DatasetMapper(cfg, is_train=True)
我們繼續了解一下DatasetMapper
的實現原理,首先看一下官方給的定義:
A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.
簡單概括就是這個類是可調用的(callable),所以在下面的源碼中可以看到定義了__call__
方法。
該類主要做了這三件事:
The callable currently does the following:
- Read the image from "file_name"
- Applies cropping/geometric transforms to the image and annotations
- Prepare data and annotations to Tensor and :class:
Instances
其源碼如下(有刪減):
class DatasetMapper:
def __init__(self, cfg, is_train=True):
# 讀取cfg的參數
...
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
# 1. 讀取圖像數據
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
# 2. 對image和box等做Transformation
if "annotations" not in dataset_dict:
image, transforms = T.apply_transform_gens(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
)
else:
...
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.crop_gen:
transforms = crop_tfm + transforms
image_shape = image.shape[:2] # h, w
# 3.將數據轉化成tensor格式
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
...
return dataset_dict
MapDataset
class MapDataset(data.Dataset):
def __init__(self, dataset, map_func):
self._dataset = dataset
self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
self._rng = random.Random(42)
self._fallback_candidates = set(range(len(dataset)))
def __len__(self):
return len(self._dataset)
def __getitem__(self, idx):
retry_count = 0
cur_idx = int(idx)
while True:
data = self._map_func(self._dataset[cur_idx])
if data is not None:
self._fallback_candidates.add(cur_idx)
return data
# _map_func fails for this idx, use a random new index from the pool
retry_count += 1
self._fallback_candidates.discard(cur_idx)
cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
if retry_count >= 3:
logger = logging.getLogger(__name__)
logger.warning(
"Failed to apply `_map_func` for idx: {}, retry count: {}".format(
idx, retry_count
)
)
self._fallback_candidates
是一個set
,它的特點是其中的元素是獨一無二的,定義這個的作用是記錄可正常讀取的數據索引,因為有的數據可能無法正常讀取,所以這個時候我們就可以把這個壞數據的索引從_fallback_candidates
中剔除,並隨機采樣一個索引來讀取數據。__getitem__
中的邏輯就是首先讀取指定索引的數據,如果正常讀取就把該所索引值加入到_fallback_candidates
中去;反之,如果數據無法讀取,則將對應索引值刪除,並隨機采樣一個數據,並且嘗試3次,若3次后都無法正常讀取數據,則報錯,但是好像也沒有退出程序,而是繼續讀數據,可能是以為總有能正常讀取的數據吧hhh。
如有意合作,歡迎私戳
郵箱:marsggbo@foxmail.com
2019-10-23 13:37:13