Pytorch自定義數據庫


1)前言

雖然torchvision.datasets中已經封裝了好多通用的數據集,但是我們在使用Pytorch做深度學習任務的時候,會面臨着自定義數據庫來滿足自己的任務需要。如我們要訓練一個人臉關鍵點檢測算法,提供的訓練數據標注如下形式,存在CSV文件中:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

在本次教程中,我們需要用到兩個額外的包:

  • scikit-image: 用於圖片io轉換
  • pandas: 用於解析csv文件

首先學習如何使用pandas庫解析csv文件

import pandas as pd
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') n = 65 img_name = landmarks_frame.iloc[n, 0] landmarks = landmarks_frame.iloc[n, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))

2)自定義數據庫

torch.utils.data.Dataset是一個表示數據庫的抽象類,自定義數據庫需要繼承這個類,並且重寫其以下方法:

__len__ :返回數據庫的大小.
__getitem__ :支持使用下標的方式 如dataset[i] 來獲取第i個樣本

以下創建人臉特征點檢測的數據庫。我們將在__init__中解析csv文件,而在__getitem__中讀取圖片。這樣可以在需要圖片是才加載,內存效率高。此外,我們還可以先將數據集封裝成lmdb數據庫,讀取速度更快。

 

import torch.utils.data.Dataset as Dataset
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): 到達標注文件cvs的路徑.
            root_dir (string): 所有圖片的根目錄.
            transform (callable, optional): (可選參數)對每一個樣本進行轉換.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0]) #第idx條數據的第一個字段,即文件名稱
        image = io.imread(img_name)                           #讀取圖像數據
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() #讀取第idx條數據的第二個字段及其之后的所有字段,即所有關鍵點的坐標。然后轉成矩陣形式
        landmarks = landmarks.astype('float').reshape(-1, 2)  #將矩陣reshape成n行兩列矩陣
        sample = {'image': image, 'landmarks': landmarks}     #封裝數據

        if self.transform:
            sample = self.transform(sample)                   #數據轉換

        return sample                                         #返回數據

注:__getitem__每次只返回一個條數據,至於batch的封裝可以在DataLoader中設置batchsize,至於讀取速度可以設置num_worker。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM