pytorch自定義dataset


參考

一個例子

import torch
from torch.utils import data

class MyDataset(data.Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = torch.randn(8,2)
    
    def __getitem__(self, index):
        return self.data[index], index
    
    def __len__(self):
        return self.data.size()[0]

data_set = MyDataset()
print(data_set.data)

輸出
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320],
[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])

data_loader = data.DataLoader(data_set, 
                              batch_size=4,
                              shuffle=False)
print(len(data_set))
for i, (number, labels) in enumerate(data_loader):
    print(number)

輸出
8
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320]])
tensor([0, 1, 2, 3])
tensor([[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])
tensor([4, 5, 6, 7])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM