TensorDataset
pytorch中TensorDateset是處理數據的工具包,其作用是將數據進行打包,例如訓練數據X和數據對應的Label,將兩者打包為一一對應的關系,即X中的一個數據對應Label中的一個值(X的一行數據對應Label中的一行數據)
(1)首先引入工具包
1 import torch 2 from torch.utils.data import TensorDataset 3 from torch.utils.data import DataLoader
(2)給定數據
1 a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]]) 2 b = torch.tensor([44, 55, 66, 77, 88, 99, 100, 110, 120, 130, 140, 150])
(3)將數據打包
1 train_data = TensorDataset(a,b)
(4)使用enumerate()方法遍歷數據,enumerate()方法是將列表或者元組或者字符串的數據給定索引,具有映射的效果。
1 for id, data in enumerate(train_data): 2 print("id:", id) 3 print(data[0]) 4 print(data[1])
id: 0 tensor([1, 2, 3]) tensor(44) id: 1 tensor([4, 5, 6]) tensor(55) id: 2 tensor([7, 8, 9]) tensor(66) id: 3 tensor([1, 2, 3]) tensor(77) id: 4 tensor([4, 5, 6]) tensor(88) id: 5 tensor([7, 8, 9]) tensor(99) id: 6 tensor([1, 2, 3]) tensor(100) id: 7 tensor([4, 5, 6]) tensor(110)
......
1 train_loader = DataLoader(train_data, batch_size=4, shuffle=True) #使用數據迭代器 2 for id, data in enumerate(train_loader): 3 x_data, lable = data 4 print('batch:{0}, x_data:{1}, lable:{2}'.format(id, x_data, lable))