TensorDataset和DataLoader的用法


TensorDataset

pytorch中TensorDateset是处理数据的工具包,其作用是将数据进行打包,例如训练数据X和数据对应的Label,将两者打包为一一对应的关系,即X中的一个数据对应Label中的一个值(X的一行数据对应Label中的一行数据)

(1)首先引入工具包

1 import torch
2 from torch.utils.data import TensorDataset
3 from torch.utils.data import DataLoader

(2)给定数据

1 a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
2 b = torch.tensor([44, 55, 66, 77, 88, 99, 100, 110, 120, 130, 140, 150])

(3)将数据打包

1 train_data = TensorDataset(a,b) 

(4)使用enumerate()方法遍历数据,enumerate()方法是将列表或者元组或者字符串的数据给定索引,具有映射的效果。

1 for id, data in enumerate(train_data):
2     print("id:", id)
3     print(data[0])
4     print(data[1])
id: 0
tensor([1, 2, 3])
tensor(44)
id: 1
tensor([4, 5, 6])
tensor(55)
id: 2
tensor([7, 8, 9])
tensor(66)
id: 3
tensor([1, 2, 3])
tensor(77)
id: 4
tensor([4, 5, 6])
tensor(88)
id: 5
tensor([7, 8, 9])
tensor(99)
id: 6
tensor([1, 2, 3])
tensor(100)
id: 7
tensor([4, 5, 6])
tensor(110)
......
1 train_loader = DataLoader(train_data, batch_size=4, shuffle=True)  #使用数据迭代器
2 for id, data in enumerate(train_loader):
3     x_data, lable = data
4     print('batch:{0}, x_data:{1}, lable:{2}'.format(id, x_data, lable))

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM