TensorDataset和DataLoader的用法


TensorDataset

pytorch中TensorDateset是處理數據的工具包,其作用是將數據進行打包,例如訓練數據X和數據對應的Label,將兩者打包為一一對應的關系,即X中的一個數據對應Label中的一個值(X的一行數據對應Label中的一行數據)

(1)首先引入工具包

1 import torch
2 from torch.utils.data import TensorDataset
3 from torch.utils.data import DataLoader

(2)給定數據

1 a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
2 b = torch.tensor([44, 55, 66, 77, 88, 99, 100, 110, 120, 130, 140, 150])

(3)將數據打包

1 train_data = TensorDataset(a,b) 

(4)使用enumerate()方法遍歷數據,enumerate()方法是將列表或者元組或者字符串的數據給定索引,具有映射的效果。

1 for id, data in enumerate(train_data):
2     print("id:", id)
3     print(data[0])
4     print(data[1])
id: 0
tensor([1, 2, 3])
tensor(44)
id: 1
tensor([4, 5, 6])
tensor(55)
id: 2
tensor([7, 8, 9])
tensor(66)
id: 3
tensor([1, 2, 3])
tensor(77)
id: 4
tensor([4, 5, 6])
tensor(88)
id: 5
tensor([7, 8, 9])
tensor(99)
id: 6
tensor([1, 2, 3])
tensor(100)
id: 7
tensor([4, 5, 6])
tensor(110)
......
1 train_loader = DataLoader(train_data, batch_size=4, shuffle=True)  #使用數據迭代器
2 for id, data in enumerate(train_loader):
3     x_data, lable = data
4     print('batch:{0}, x_data:{1}, lable:{2}'.format(id, x_data, lable))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM