pytorch使用自定義數據集
DataLoader是pytorch提供的,一般我們要寫的是Dataset,也就是DataLoader中的一個參數,其基本框架是:
class CustomDataset(data.Dataset):#需要繼承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#這里需要注意的是,第一步:read one data,是一個data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
由此可見,需要暴露的API只有__getitem__
和__len__
,還有一個構造函數