本文簡單描述如果自定義dataset,代碼並未經過測試(只是說明思路),為半偽代碼。所有邏輯需按自己需求另外實現:
一、分析DataLoader
train_loader = DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size, shuffle=True)
datasets.MNIST()是一個torch.utils.data.Datasets對象,batch_size表示我們定義的batch大小(即每輪訓練使用的批大小),shuffle表示是否打亂數據順序(對於整個datasets里包含的所有數據)。
對於batch_size和shuffle都是根據業務需求來認為指定的,不做過多說明。
對於Datasets對象來說,我們可以根據自己的數據類型來自定義,自己定義一個類,繼承Datasets類。
二、分析Datasets類
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])
上述代碼是pytorch中Datasets的源碼,注意成員方法__getitem__和__len__都是未實現的。我們要實現自定義Datasets類來完成數據的讀取,則只需要完成這兩個成員方法的重寫。
首先,__getitem__()方法用來從datasets中讀取一條數據,這條數據包含訓練圖片(已CV距離)和標簽,參數index表示圖片和標簽在總數據集中的Index。
__len__()方法返回數據集的總長度(訓練集的總數)。
三、簡單實現MyDatasets類
# -*- coding:utf-8 -*- __author__ = 'Leo.Z' import os from torch.utils.data import Dataset from torch.utils.data import DataLoader import matplotlib.image as mpimg # 對所有圖片生成path-label map.txt def generate_map(root_dir): current_path = os.path.abspath(__file__) father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".") with open(root_dir + 'map.txt', 'w') as wfp: for idx in range(10): subdir = os.path.join(root_dir, '%d/' % idx) for file_name in os.listdir(subdir): abs_name = os.path.join(father_path, subdir, file_name) linux_abs_name = abs_name.replace("\\", '/') wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx)) # 實現MyDatasets類 class MyDatasets(Dataset): def __init__(self, dir): # 獲取數據存放的dir # 例如d:/images/ self.data_dir = dir # 用於存放(image,label) tuple的list,存放的數據例如(d:/image/1.png,4) self.image_target_list = [] # 從dir--label的map文件中將所有的tuple對讀取到image_target_list中 # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路徑最好是絕對路徑 with open(os.path.join(dir, 'map.txt'), 'r') as fp: content = fp.readlines() str_list = [s.rstrip().split() for s in content] # 將所有圖片的dir--label對都放入列表,如果要執行多個epoch,可以在這里多復制幾遍,然后統一shuffle比較好 self.image_target_list = [(x[0], int(x[1])) for x in str_list] def __getitem__(self, index): image_label_pair = self.image_target_list[index] # 按path讀取圖片數據,並轉換為圖片格式例如[3,32,32] img = mpimg.imread(image_label_pair[0]) return img, image_label_pair[1] def __len__(self): return len(self.image_target_list) if __name__ == '__main__': # 生成map.txt # generate_map('train/') train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True) for step in range(20000): for idx, (img, label) in enumerate(train_loader): print(img.shape) print(label.shape)
上述代碼簡要說明了利用Datasets類和DataLoader類來讀取數據,本例用的是圖片原始數據,大概的結構如下:
如果使用其他形式的數據,例如二進制文件,則需要字節讀取文件,分割成每一張圖片和label,然后從__getitem__中返回就可以了。例如cifar-10數據,我們只需要在__getitem__方法中,按index來讀取對應位置的字節,然后轉換為label和img,並返回。在__len__中返回cifar-10訓練集的總樣本數。DataLoader就可以根據我們提供的index,len以及batch_size,shuffle來返回相應的batch數據和label。