SSD源碼解讀——數據讀取


之前,對SSD的論文進行了解讀,可以回顧之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html

為了加深對SSD的理解,因此對SSD的源碼進行了復現,主要參考的github項目是ssd.pytorch。同時,我自己對該項目增加了大量注釋:https://github.com/Dengshunge/mySSD_pytorch

搭建SSD的項目,可以分成以下三個部分:

  1. 數據讀取;
  2. 網絡搭建
  3. 損失函數的構建
  4. 網絡測試

接下來,本篇博客重點分析數據讀取


一、整體框架

SSD的數據讀取環節,同樣適用於大部分目標檢測的環節,具有通用性。為了方便理解,本項目以VOC2007+2012為例。因此,數據讀取環節,通常是按照以下步驟展開進行:

  1. 函數入口;
  2. 圖片的讀取和xml文件的讀取;
  3. 對GT框進行處理;
  4. 數據增強;
  5. 輔助函數。

二、具體實現細節

2.1 函數入口

數據讀取的函數入口在train.py文件中:

if args.dataset == 'VOC':
    train_dataset = VOCDetection(root=args.dataset_root)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=4,
        collate_fn=detection_collate, shuffle=True, pin_memory=True)

可以看到,首先通過函數 VOCDetection() 來對VOC數據集進行初始化,再使用函數 DataLoader() 來實現對數據集的讀取。這一步與常見的分類網絡形式相同,但不同的是,多了collate_fn這一參數,后續會對此進行說明。

2.2 圖片與xml文件讀取

首先,我們先看看函數VOCDetection() 的初始化函數__init__()。在__init__中包含了需要傳入的幾個參數,image_sets(表示VOC使用到的數據集),transform(數據增強的方式),target_transform(GT框的處理方式)。

class VOCDetection():
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the input image
            圖片預處理的方式,這里使用了大量數據增強的方式
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
            真實框預處理的方式
    """

    def __init__(self, root,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=SSDAugmentation(size=config.voc['min_dim'], mean=config.MEANS),
                 target_transform=VOCAnnotationTransform()):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = []
        # 使用VOC2007和VOC2012的train作為訓練集
        for (year, name) in self.image_set:
            rootpath = os.path.join(self.root, 'VOC' + year)
            for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append([rootpath, line[:-1]])

首先,為什么image_sets是這樣的形式呢?因為VOC具有固定的文件夾路徑,利用這個參數和配合路徑讀取,可以讀取到txt文件,該txt文件用於制定哪些圖片用於訓練。此外,還需要設置參數self.ids,這個list用於存儲文件的路徑,由兩列組成,"VOC/2007"和圖片名稱。通過這兩個參數,后續可以配合函數_annopath()和_imgpath()可以讀取到對應圖片的路徑和xml文件。

在pytorch中,還需要相應的函數來對讀取圖片與返回結果,如下所示。其中,重點是pull_iterm函數。

    def __getitem__(self, index):
        im, gt = self.pull_item(index)
        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        img_id = tuple(self.ids[index])
        # img_id里面有2個值
        target = ET.parse(self._annopath % img_id).getroot()  # 獲得xml的內容,但這個是具有特殊格式的
        img = cv2.imread(self._imgpath % img_id)
        height, width, _ = img.shape

        if self.target_transform is not None:
            # 真實框處理
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            # 圖像預處理,進行數據增強,只在訓練進行數據增強,測試的時候不需要
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # 轉換格式
            img = img[:, :, (2, 1, 0)]  # to rbg
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target

該函數pull_item(),首先讀取圖片和相應的xml文件;接着對使用類VOCAnnotationTransform來對GT框進行處理,即讀取GT框坐標與將坐標歸一化;然后通過函數SSDAugmentation()對圖片進行數據增強;最后對對圖片進行常規處理(交換通道等),返回圖片與存有GT框的list。

2.3 對GT框進行處理

接着,需要講一講這個類VOCAnnotationTransform的作用,其定義如下。self.class_to_ind是一個map,其key是類別名稱,value是編號,這個對象的作用是,讀取xml時,能將對應的類別名稱轉換成label;在__call__()函數中,主要是xml讀取的一些方式,值得一提的是,GT框的最表轉換成了[0,1]之間,當圖片尺寸變化了,GT框的坐標也能進行相應的變換。最后,res的每行由5個元素組成,分別是[x_min,y_min,x_max,y_max,label]。

class VOCAnnotationTransform():
    '''
    獲取xml里面的坐標值和label,並將坐標值轉換成0到1
    '''

    def __init__(self, class_to_ind=None, keep_difficult=False):
        # 將類別名字轉換成數字label
        self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        # 在xml里面,有個difficult的參數,這個表示特別難識別的目標,一般是小目標或者遮擋嚴重的目標
        # 因此,可以通過這個參數,忽略這些目標
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        '''
        將一張圖里面包含若干個目標,獲取這些目標的坐標值,並轉換成0到1,並得到其label
        :param target: xml格式
        :return: 返回List,每個目標對應一行,每行包括5個參數[xmin, ymin, xmax, ymax, label_ind]
        '''
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1  # 判斷該目標是否為難例
            # 判斷是否跳過難例
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()  # text是獲得目標的名稱,lower將字符轉換成小寫,strip去除前后空格
            bbox = obj.find('bndbox')  # 獲得真實框坐標

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1  # 獲得坐標值
                # 將坐標轉換成[0,1],這樣圖片尺寸發生變化的時候,真實框也隨之變化,即平移不變形
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]  # 獲得名字對應的label
            bndbox.append(label_idx)
            res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]
        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

2.4 數據增強

還有一個重要的函數,即函數SSDAugmentation(),該函數的作用是作數據增強。論文中也提及了,數據增強對最終的結果提升有着重大作用。博客1博客2具體講述了數據增強的源碼,講得十分詳細。在本項目中,SSDAugmentation()函數在data/augmentations.py中,如下所示。由於opencv讀取讀片的時候,取值范圍是[0,255],是int類型,需要將其轉換為float類型,計算其GT框的正式坐標。然后對圖片進行光度變形,包含改變對比度,改變飽和度,改變色調、改變亮度和增加噪聲等。接着有對圖片進行擴張和裁剪等。在此操作中,會涉及到GT框坐標的變換。最后,當上述變化處理完后,再對GT框坐標歸一化,和resize圖片,減去均值等。具體細節,可以參考兩篇博客進行解讀。

class SSDAugmentation(object):
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        self.augment = Compose([
            ConvertFromInts(),  # 將圖片從int轉換成float
            ToAbsoluteCoords(),  # 計算真實的錨點框坐標
            PhotometricDistort(),  # 光度變形
            Expand(self.mean),  # 隨機擴張圖片
            RandomSampleCrop(),  # 隨機裁剪
            RandomMirror(),  # 隨機鏡像
            ToPercentCoords(),
            Resize(self.size),
            SubtractMeans(self.mean)
        ])

    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)

2.5 輔助函數

在一個batch中,每張圖片的GT框數量是不等的,因此,需要定義一個函數來處理這種情況。函數detection_collate()就是用於處理這種情況,使得一張圖片能對應一個list,這里list里面有所有GT框的信息組成。

def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).
    自定義處理在同一個batch,含有不同數量的目標框的情況

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on
                                 0 dim
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(sample[0])
        targets.append(torch.FloatTensor(sample[1]))
    return torch.stack(imgs, 0), targets

 


 

至此,已經將SSD的數據讀取部分分析完。 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM