pytorch使用自定義數據集


pytorch使用自定義數據集

DataLoader是pytorch提供的,一般我們要寫的是Dataset,也就是DataLoader中的一個參數,其基本框架是:

class CustomDataset(data.Dataset):#需要繼承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #這里需要注意的是,第一步:read one data,是一個data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

由此可見,需要暴露的API只有__getitem____len__,還有一個構造函數


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM