[深度学习]-Dataset数据集加载


加载数据集dataloader

from torch.utils.data import DataLoader
# Placeholder module name: replace it with the module that holds your own
# Dataset class.
from 自己写的dataset import Dataset

# One dataset object per split, selected through the boolean `train` flag.
train_set = Dataset(train=True)
val_set = Dataset(train=False)

image_datasets = {'train': train_set, 'val': val_set}

batch_size = 4

# Only the training split is shuffled; validation keeps a fixed order.
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2),
}

# Number of samples in each split, keyed by split name.
dataset_sizes = {name: len(split) for name, split in image_datasets.items()}
print(dataset_sizes)

for epoch in range(num_epochs):
    for phase in ('train', 'val'):
        # Put the model in the mode that matches the current phase.
        # Debugging hint: to inspect the learning rate during 'train', print
        # param_group['lr'] for each param_group in optimizer.param_groups.
        if phase != 'train':
            model.eval()
        else:
            model.train()

以上写法适用于每个 epoch 先训练(train)一遍、再验证(val)一遍的情况

或者分别加载训练和测试:

# Build one loader per split. `collate_fn` must be supplied by the caller.
# The Dataset class shown in this file takes a boolean `train` flag, so pass
# True/False: a non-empty string such as 'eval' is truthy and would select the
# training files.
train_dataset = Dataset(train=True)
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True,
    num_workers=2, collate_fn=collate_fn)

test_dataset = Dataset(train=False)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=8, shuffle=False,
    num_workers=2, collate_fn=collate_fn)

自己写Dataset

from torch.utils.data import Dataset
import os
import cv2
import torch
import numpy as np


class Dataset(Dataset):
    """Skeleton of a custom dataset backed by two line-oriented text files.

    One file lists the inputs and the other lists the matching ground truth.
    Both are read fully into memory when the object is constructed.
    """

    def __init__(self, train):
        """Select the file pair for one split and load both line lists.

        Args:
            train: Truthy selects the training files, falsy the test files.
                Pass a real bool: any non-empty string counts as truthy.
        """
        super().__init__()
        if train:
            self.datapath = {
                'image': '/home/myy/code/Final_Project/data_train.txt',
                'target': '/home/myy/code/Final_Project/gt_train.txt',
            }
        else:
            # For quick experiments a smaller file pair can be substituted
            # here (e.g. test_small_data.txt / test_small.txt).
            self.datapath = {
                'image': '/home/myy/code/Final_Project/data_test.txt',
                'target': '/home/myy/code/Final_Project/gt_test.txt',
            }
        self.image_list, self.target_list = self.read_txt(self.datapath)

    # Helper functions can be added here as needed; remember to call them
    # through `self.`. read_txt and read_json below are two such helpers.
    def read_txt(self, datapath):
        """Read the 'image' and 'target' text files named in *datapath*.

        Args:
            datapath: Mapping with the keys 'image' and 'target', each
                holding the path of a text file.

        Returns:
            A tuple ``(image_list, target_list)`` of the raw lines of each
            file, exactly as returned by ``readlines()`` (line endings are
            kept).
        """
        print(datapath)
        with open(datapath['image'], 'r') as f:
            image_list = f.readlines()
        with open(datapath['target'], 'r') as f:
            target_list = f.readlines()
        return image_list, target_list

    @staticmethod
    def read_json(save_path, encoding='utf8'):
        """Load a JSON file and return its values as a list.

        Declared as a static method so it works both as
        ``Dataset.read_json(path)`` and as ``self.read_json(path)``.

        Args:
            save_path: Path of the JSON file.
            encoding: Text encoding used to open the file.

        Returns:
            A list holding ``content[key]`` for every key of the decoded
            top-level JSON value, in iteration order.
        """
        # Imported locally because the file's import section does not
        # provide `json`.
        import json

        with open(save_path, 'r', encoding=encoding) as f:
            content = json.load(f)
        return [content[key] for key in content]

    def __getitem__(self, item):
        """Return the ``(input, ground truth)`` pair at index *item*.

        This is the core of the dataset. As a runnable default it returns
        the two text entries with surrounding whitespace stripped.
        """
        img = self.image_list[item].strip()
        target = self.target_list[item].strip()
        # TODO: replace the raw text entries with the real sample, e.g. load
        # and preprocess the image and parse the target.
        return img, target

    def __len__(self):
        """Return the number of samples (one per line of the image list)."""
        # Adjust if the dataset size is defined differently for your data.
        return len(self.image_list)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM