pytorch實現批訓練


代碼:

#進行批訓練
import torch
import torch.utils.data as Data

BATCH_SIZE = 5  #每批5個數據

if __name__ == '__main__':
    x = torch.linspace(1, 10, 10)  #x是從1到10共10個數據
    y = torch.linspace(10, 1, 10)  #y是從10到1共10個數據

    #torch_dataset = Data.TensorDataset(data_tensor = x, target_tensor=y)會報錯
    torch_dataset = Data.TensorDataset(x,y)
    loader = Data.DataLoader(      #使我們的訓練變成一小批一小批的
        dataset = torch_dataset,   #將所有數據放入dataset中
        batch_size= BATCH_SIZE,
        shuffle=True,              #true訓練的時候隨機打亂數據,false不打亂
        num_workers=2,             #每次訓練用兩個線程或進程進行提取
    )   

    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):  #利用enumerate可以同時獲得索引(step)和值
            print('Epoch:', epoch, '| Step:', step, '| batch_x:', 
            batch_x.numpy(), '| batch_y:', batch_y.numpy())

過程中遇到了問題,問題及解決辦法都在https://blog.csdn.net/thunderf/article/details/94733747


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM