PyTorch data reading mechanism:
The sampler generates indices, and the image and label are fetched from the Dataset according to each index.
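As a quick illustration of that mechanism (a sketch, not part of the original notes), the samplers DataLoader uses internally can be iterated directly: SequentialSampler yields indices in order, RandomSampler yields a shuffled permutation.

from torch.utils.data import SequentialSampler, RandomSampler

data = list(range(5))                    # any sized collection works as the data_source
print(list(SequentialSampler(data)))     # [0, 1, 2, 3, 4]
print(list(RandomSampler(data)))         # e.g. [3, 0, 4, 1, 2], a random permutation of the indices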
1.torch.utils.data.DataLoader
Purpose: builds an iterable data loader over a dataset.
dataset: a Dataset object that decides where the data is read from and how it is read
batch_size: the batch size
num_workers: number of worker processes used to read data; when resources allow, multi-process loading speeds up data reading
shuffle: whether the data is reshuffled at every epoch
drop_last: whether to drop the last batch when the number of samples is not divisible by batch_size
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
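A minimal usage sketch for these arguments (not from the original notes; the TensorDataset of random values is only a stand-in for a real dataset):

import torch
from torch.utils.data import TensorDataset, DataLoader

# 16 fake samples of shape (3, 32, 32) with binary labels, purely for illustration
fake_imgs = torch.randn(16, 3, 32, 32)
fake_labels = torch.randint(0, 2, (16,))
dataset = TensorDataset(fake_imgs, fake_labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

for imgs, labels in loader:
    print(imgs.shape, labels.shape)      # torch.Size([4, 3, 32, 32]) torch.Size([4])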
epoch: one complete pass in which every training sample has been fed into the model
iteration: one batch of samples fed into the model
batch_size: the batch size; it determines how many iterations one epoch contains
For example:
Total samples: 80, batch_size: 8
1 epoch = 10 iterations
Total samples: 87, batch_size: 8
1 epoch = 10 iterations when drop_last=True
1 epoch = 11 iterations when drop_last=False
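The 87-sample case can be checked directly; this is a throwaway sketch (the random tensor stands in for real data):

import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.randn(87, 1))                    # 87 dummy samples

loader_keep = DataLoader(dataset, batch_size=8, drop_last=False)
loader_drop = DataLoader(dataset, batch_size=8, drop_last=True)

print(len(loader_keep))   # 11 -- the last, incomplete batch of 7 samples is kept
print(len(loader_drop))   # 10 -- the incomplete batch is dropped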
2.torch.utils.data.Dataset
Purpose: Dataset is an abstract class; every custom dataset must inherit from it and override
__getitem__()
__getitem__: receives an index and returns the corresponding sample
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
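As a hypothetical minimal example of this contract (not from the original notes), a Dataset wrapping two Python lists only needs __getitem__ and __len__:

from torch.utils.data import Dataset

class ToyDataset(Dataset):                   # hypothetical toy dataset
    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys

    def __getitem__(self, index):            # receive an index, return one sample
        return self.xs[index], self.ys[index]

    def __len__(self):                       # tells DataLoader how many samples there are
        return len(self.xs)

toy = ToyDataset([0.1, 0.2, 0.3], [0, 1, 0])
print(len(toy), toy[1])                      # 3 (0.2, 1)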
RMB (renminbi) classification example:
Data splitting:
import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:
            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point,
                                                                 valid_point - train_point,
                                                                 img_count - valid_point))
Creating the Dataset:
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for the RMB denomination classification task
        :param data_dir: str, path to the dataset
        :param transform: torch.transform, data preprocessing
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info stores the paths and labels of all images; DataLoader reads samples by index
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)                 # apply transforms here, e.g. convert to tensor

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over classes
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info
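A sketch of how this Dataset is typically wired into a DataLoader (not part of the original notes; the transform and batch size are illustrative, and train_dir assumes the split produced by the script above):

import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_dir = os.path.join("..", "..", "data", "rmb_split", "train")   # output of the split script above

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)

for imgs, labels in train_loader:
    print(imgs.shape, labels)        # e.g. torch.Size([16, 3, 32, 32]) and a tensor of 0/1 labels
    break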
3.transforms
torchvision.transforms: commonly used image processing methods
centering, standardization, scaling, cropping, rotation, flipping, padding, noise injection, grayscale conversion, linear transforms, affine transforms, brightness/saturation/contrast adjustment
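A hedged sketch chaining several of these operations with torchvision.transforms.Compose (the specific parameter values are illustrative, not prescribed by the notes):

import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),                        # scaling
    transforms.RandomCrop(32, padding=4),               # cropping with padding
    transforms.RandomHorizontalFlip(p=0.5),             # flipping
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # brightness/saturation/contrast
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],    # standardization (per-channel)
                         std=[0.229, 0.224, 0.225]),
])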