TensorDataset
pytorch中TensorDateset是处理数据的工具包,其作用是将数据进行打包,例如训练数据X和数据对应的Label,将两者打包为一一对应的关系,即X中的一个数据对应Label中的一个值(X的一行数据对应Label中的一行数据)
(1)首先引入工具包
1 import torch 2 from torch.utils.data import TensorDataset 3 from torch.utils.data import DataLoader
(2)给定数据
1 a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]]) 2 b = torch.tensor([44, 55, 66, 77, 88, 99, 100, 110, 120, 130, 140, 150])
(3)将数据打包
1 train_data = TensorDataset(a,b)
(4)使用enumerate()方法遍历数据,enumerate()方法是将列表或者元组或者字符串的数据给定索引,具有映射的效果。
1 for id, data in enumerate(train_data): 2 print("id:", id) 3 print(data[0]) 4 print(data[1])
id: 0 tensor([1, 2, 3]) tensor(44) id: 1 tensor([4, 5, 6]) tensor(55) id: 2 tensor([7, 8, 9]) tensor(66) id: 3 tensor([1, 2, 3]) tensor(77) id: 4 tensor([4, 5, 6]) tensor(88) id: 5 tensor([7, 8, 9]) tensor(99) id: 6 tensor([1, 2, 3]) tensor(100) id: 7 tensor([4, 5, 6]) tensor(110)
......
1 train_loader = DataLoader(train_data, batch_size=4, shuffle=True) #使用数据迭代器 2 for id, data in enumerate(train_loader): 3 x_data, lable = data 4 print('batch:{0}, x_data:{1}, lable:{2}'.format(id, x_data, lable))