pytorch中DataSet和DataLoader的使用詳解(Subset,ConcatDataset)


1. 首先導入需要用到的包

from torch.utils.data import DataLoader,Dataset

2. 自定義Dataset

一般情況下我們使用Dataset,需要自定義一個類來繼承Dataset,然后實現__getitem__()方法和__len__()方法
使用示例如下所示:

import torch
a = [[1,2,3,4],[4,5,6,7,9],[6,7,8,9,4,5],[4,3,2],[8,7,5,4],[4,8,7,1]]
b = [1,2,3,4,5,6]

class mydataset(Dataset):
    def __init__(self,x,y):
        self.feature = x
        self.label = y
    
    def __getitem__(self,item):
        return torch.tensor(self.feature[item]),self.label[item]   #根據需要進行設置

    def __len__(self):
        return len(self.feature)

dataset = mydataset(a,b)

print(dataset[0])

程序運行結果如下所示:

(tensor([1, 2, 3, 4]), 1)

3. 創建DataLoader

DataLoader需要傳入幾個參數,先看一下官方的定義:

常用到的幾個參數解釋如下:

# dataset:數據集,傳入我們剛才創建的數據集即可;
# batch_size:每個batch的大小
# collate_fn:按照定義函數的方式進行取數據
# shuffle:是否將數據集中的數據進行打亂

使用示例如下所示:

def fun(x):                                                    # 根據自己的需求定義dataloader返回數據的格式
    x.sort(key=lambda data:len(data[0]),reverse=True)
    # print(x)
    feature = []
    label = []
    length = []
    for i in x:
        feature.append(i[0])
        label.append(i[1])
        length.append(len(i[0]))
    # feature = pad_sequence(feature,batch_first=True,padding_value=-1)     # 可以適當的進行補齊操作
    return feature,label,length


dataloader = DataLoader(dataset,batch_size=2,collate_fn=fun)    # 定義DataLoader

for x,y,length in dataloader:
    print(x,y,length)
    print('------------------')

程序運行結果如下所示:

[tensor([4, 5, 6, 7, 9]), tensor([1, 2, 3, 4])] [2, 1] [5, 4]
------------------
[tensor([6, 7, 8, 9, 4, 5]), tensor([4, 3, 2])] [3, 4] [6, 3]
------------------
[tensor([8, 7, 5, 4]), tensor([4, 8, 7, 1])] [5, 6] [4, 4]
  1. Subset的使用
    首先看一下官網的定義:

    該類的用處是從一個大的數據集中取出一部分作為數據子集,其中indices是索引值,可以是列表形式。
    5.ConcatDataset的使用
    官網的定義如下:

    該類的用處是將多個數據子集合並為一個整的數據集,其中參數datasets是需要合並的數據子集集合,以列表的形式給出。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM