- https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
- Setting num_workers too high raises an error (a multiprocessing error; it does not occur when using the GPU)
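A likely cause of that crash (my assumption; the source does not say) is creating workers without the main-module guard that spawn-based platforms such as Windows require. A minimal sketch, with `main` as a made-up helper name:

```python
import torch
import torch.utils.data as Data

def main():
    # hypothetical helper: returns the size of each batch produced
    x = torch.linspace(1, 10, 10)
    ds = Data.TensorDataset(x)
    loader = Data.DataLoader(ds, batch_size=4, num_workers=2)
    return [len(batch) for (batch,) in loader]

# On Windows (and macOS with the default "spawn" start method), each worker
# process re-imports this module, so iteration must only happen under the guard:
if __name__ == "__main__":
    print(main())  # [4, 4, 2]
```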
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 3 23:30:39 2020
@author: Administrator
"""
import torch                        # import PyTorch
import torch.utils.data as Data

BATCH_SIZE = 8                      # number of samples per batch

x = torch.linspace(1, 10, 10)       # x: 10 evenly spaced values from 1 to 10
y = torch.linspace(10, 1, 10)       # y: 10 evenly spaced values from 10 to 1

# Wrap the tensors in a Dataset that torch can work with.
# You can also define a custom Dataset: https://www.cnblogs.com/douzujun/p/13429912.html
torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,          # the dataset to load from
    batch_size=BATCH_SIZE,          # each batch holds BATCH_SIZE = 8 samples
    shuffle=True,                   # whether to shuffle the data each epoch
    num_workers=0,                  # worker processes for loading; 0 means load in the main process
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('epoch', epoch, '| step:', step, '| batch_x', batch_x.numpy(),
              '| batch_y:', batch_y.numpy())

The output:
epoch 0 |step: 0 | batch_x [ 7. 3. 1. 8. 10. 9. 5. 4.] |batch_y: [ 4. 8. 10. 3. 1. 2. 6. 7.]
epoch 0 |step: 1 | batch_x [2. 6.] |batch_y: [9. 5.]
epoch 1 |step: 0 | batch_x [ 6. 7. 5. 4. 1. 10. 2. 9.] |batch_y: [ 5. 4. 6. 7. 10. 1. 9. 2.]
epoch 1 |step: 1 | batch_x [3. 8.] |batch_y: [8. 3.]
epoch 2 |step: 0 | batch_x [ 4. 5. 7. 1. 6. 9. 10. 3.] |batch_y: [ 7. 6. 4. 10. 5. 2. 1. 8.]
epoch 2 |step: 1 | batch_x [8. 2.] |batch_y: [3. 9.]
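The custom-Dataset route mentioned in the comment above only needs `__len__` and `__getitem__`. A minimal sketch; `PairDataset` is a made-up name that just mirrors what `TensorDataset` does for two tensors:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """Hypothetical example: serves (x, y) pairs from two equal-length tensors."""
    def __init__(self, x, y):
        assert len(x) == len(y)
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)                    # number of samples

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]       # one (x, y) pair

ds = PairDataset(torch.linspace(1, 10, 10), torch.linspace(10, 1, 10))
loader = DataLoader(ds, batch_size=8, shuffle=True)
```

DataLoader itself handles batching and shuffling, so iterating this loader behaves exactly like the TensorDataset version above: one batch of 8 followed by one of 2.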
The signature of DataLoader is:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)
- dataset: the dataset to load from (a Dataset object)
- batch_size: the batch size
- shuffle: whether to shuffle the data
- sampler: the sample-drawing strategy; covered in detail later
- num_workers: number of worker processes for loading; 0 means load in the main process
- collate_fn: how to merge multiple samples into one batch; the default usually suffices
- pin_memory: whether to keep the data in pinned (page-locked) memory; transfers from pinned memory to the GPU are faster
- drop_last: the number of samples in the dataset may not be an integer multiple of batch_size; with drop_last=True, the leftover samples that do not fill a full batch are discarded
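The drop_last and collate_fn parameters can be seen directly on the 10-sample dataset from above. A sketch; `stack_pairs` is a made-up name that reproduces what the default collation does for tensor pairs:

```python
import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
ds = Data.TensorDataset(x, y)

# 10 samples with batch_size=8 leaves a leftover batch of 2;
# drop_last=True discards it, so only the one full batch remains
full = Data.DataLoader(ds, batch_size=8, drop_last=False)
trimmed = Data.DataLoader(ds, batch_size=8, drop_last=True)
print(len(full), len(trimmed))  # 2 1

# collate_fn receives a list of (x, y) samples and must return the batch;
# this one stacks the xs and ys, just like the default does
def stack_pairs(samples):
    xs = torch.stack([s[0] for s in samples])
    ys = torch.stack([s[1] for s in samples])
    return xs, ys

loader = Data.DataLoader(ds, batch_size=8, collate_fn=stack_pairs)
first_x, first_y = next(iter(loader))
print(first_x.shape)  # torch.Size([8])
```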