Pytorch基礎(5)——批數據訓練


一、知識點:

  • 相關包:torch.utils.data

import torch import torch.utils.data as Data
  • 包裝數據類:TensorDataset

【包裝數據和目標張量的數據集,通過沿着第一個維度索引兩個張量來】

class torch.utils.data.TensorDataset(data_tensor, target_tensor)
#data_tensor (Tensor) - 包含樣本數據
#target_tensor (Tensor) - 包含樣本目標(標簽)

 

  • 加載數據類:DataLoader

【數據加載器。組合數據集和采樣器,並在數據集上提供單進程或多進程迭代器。】

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
#num_workers (int, optional) – 用多少個子進程加載數據
#drop_last (bool, optional) – 如果數據集大小不能被batch size整除,則設置為True后可刪除最后一個不完整的batch。如果設為False並且數據集的大小不能被batch size整除,則最后一個batch將更小。(默認: False)

 

二、利用torch.utils.data進行批數據訓練:

導入包:

import torch import torch.utils.data as Data

設置參數並創建數據:

Batch_size = 5 x = torch.linspace(1,10,10) y = torch.linspace(10,1,10)

將數據包裝到TensorDataset中:

torch_dataset = Data.TensorDataset(x , y)

加載數據:

loader = Data.DataLoader( dataset = torch_dataset, batch_size = Batch_size, shuffle=True, num_workers = 2,  #采用兩個進程來提取
)

epoch 3次,每次epoch的訓練步數steps = 2【batch_size = 5,總數據量為10】:

若最后不夠一個batch_size,就只拿剩下的。

for epoch in range(3): for step , (batch_x,batch_y) in enumerate(loader): #training……
        print('epoch:',epoch, '| step:',step, '| batch_x:',batch_x.numpy(), '| batch_y:',batch_y.numpy() )

 

結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM