代碼:
#進行批訓練 import torch import torch.utils.data as Data BATCH_SIZE = 5 #每批5個數據 if __name__ == '__main__': x = torch.linspace(1, 10, 10) #x是從1到10共10個數據 y = torch.linspace(10, 1, 10) #y是從10到1共10個數據 #torch_dataset = Data.TensorDataset(data_tensor = x, target_tensor=y)會報錯 torch_dataset = Data.TensorDataset(x,y) loader = Data.DataLoader( #使我們的訓練變成一小批一小批的 dataset = torch_dataset, #將所有數據放入dataset中 batch_size= BATCH_SIZE, shuffle=True, #true訓練的時候隨機打亂數據,false不打亂 num_workers=2, #每次訓練用兩個線程或進程進行提取 ) for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): #利用enumerate可以同時獲得索引(step)和值 print('Epoch:', epoch, '| Step:', step, '| batch_x:', batch_x.numpy(), '| batch_y:', batch_y.numpy())
過程中遇到了問題,問題及解決辦法都在https://blog.csdn.net/thunderf/article/details/94733747