PyTorch 數據集類 和 數據加載類 的一些嘗試


最近在學習PyTorch,  但是對里面的數據類和數據加載類比較迷糊,可能是封裝的太好大部分情況下是不需要有什么自己的操作的,不過偶然遇到一些自己導入的數據時就會遇到一些問題,因此自己對此做了一些小實驗,小嘗試。

 

 

下面給出一個常用的數據類使用方式:

def data_tf(x):
    x = np.array(x, dtype='float32') / 255 # 將數據變到 0 ~ 1 之間
    x = (x - 0.5) / 0.5 # 標准化,這個技巧之后會講到
    x = x.reshape((-1,)) # 拉平
    x = torch.from_numpy(x)
    return x



from torchvision.datasets import MNIST # 導入 pytorch 內置的 mnist 數據
train_set = MNIST('./data', train=True, transform=data_tf, download=True) # 載入數據集,申明定義的數據變換
test_set = MNIST('./data', train=False, transform=data_tf, download=True)

 

 

其中,  data_tf  並不是必須要有的,比如:

from torchvision.datasets import MNIST # 導入 pytorch 內置的 mnist 數據
train_set = MNIST('./data', train=True, download=True) # 載入數據集,申明定義的數據變換
test_set = MNIST('./data', train=False, download=True)

這里面的MNIST類是框架自帶的,可以自動下載MNIST數據庫,   ./data  是指將下載的數據集存放在當前目錄下的哪個目錄下,    train 這個屬性 True時 則在 ./data文件夾下面在建立一個 train的文件夾然后把下載的數據存放在其中,  當train屬性是False的時候則把下載的數據放在 test文件夾下面。   

划線部分是老版本的PyTorch的處理方式,  最近試了一下最新版本  PyTorch 1.0   ,   train為True的時候是把數據放在  ./data/processed  文件夾下面, 命名為training.pt  ,  為False 的時候則放在  ./data/processed  文件夾下面, 命名為test.pt  。

 

 

 

 

 

 

這時候就出現了一個問題, 如果你使用的數據集不是框架自帶的那么如何使用數據類呢,這個時候就要使用  pytorch 中的  Dataset 類了。

from torch.utils.data import Dataset

我們需要重寫   Dataset類, 需要實現的方法為  __len__   和   __getitem__    這兩個內置方法,  這里可以看出其思想就是要重寫的類需要支持按照索引查找的方法。

 

 

 

 

這里我們還是舉個例子:

 

 

 

 

從這個例子可以看出  mydataset就是我們自定義的 myDataset 類生成的自定義數據類對象。我們可以在myDataset類中自定義一些方法來對需要的數據進行處理。

為說明該問題另附加一個例子:

from torch.utils.data import Dataset


#需要在pytorch中使用的數據
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]]


class myDataset(Dataset):
    def __init__(self, indata):
        self.data=indata
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


mydataset=myDataset(data)

 

 

那么又來了一個問題,我們不重寫 Dataset類的話可不可以呢, 經過嘗試發現還真可以,如下:

 

 

 

 

又如:

 

 

 

由這個例子可以看出數據類對象可以不重寫Dataset類, 只要具備  __len__      __getitem__    方法就可以。而且從這個例子我們可以看出  DataLoader  是一個迭代器, 如果shuffle 設置為 True 那么在每次迭代之前都會重新排序。

同時由上面兩個例子可以看出  DataLoader類會把傳入的數據集合中的數據轉化為  torch.tensor 類型, 當然是采用默認的  DataLoader類中轉化函數 transform的情況下。

這也就是說  DataLoader 默認的轉化函數 transform操作為    傳入的[ [x, x, x], [y, y, y] ] 輸出的是 [ tensor([x, x, x]),  tensor([y, y, y]) ] ,

傳入的是  tensor([ [x, x, x], [y, y, y] ]) 輸出的是 tensor([ tensor([x, x, x]),  tensor([y, y, y]) ] ),   (這個例子是在   batch_size=2 的情況)。

 

 

 

綜上,可知  其實   Dataset類, 和 DataLoader類其實在pytorch 計算過程中都不是一定要有的,  其中Dataset類是起一個規范作用,意義在於要人們對不同的類型數據做一些初步的調整,使其支持按照索引讀取,以使其可以在 DataLoader中使用。

DataLoader 是一個迭代器, 可以方便的通過設置 batch_size 來實現 batch過程,transform則是對數據的一些處理。

 

 

 

 

---------------------------------------------------------------------------------------------------

 

上述內容更正:

 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


#需要在pytorch中使用的數據
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]]

class myDataset(Dataset):
    def __init__(self, indata):
        self.data=indata
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True)

print("上文的錯誤操作:")

for i in train_data:
    print(i)
    print('-'*30)
print('again')
for i in train_data:
    print(i)
    print('-'*30)


#########################################


data=np.array(data)
data=torch.from_numpy(data)


mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True)


print("修正后的正確操作:")

for i in train_data:
    print(i)
    print('-'*30)
print('again')
for i in train_data:
    print(i)
    print('-'*30)

 

 

 

(base) devil@devilmaycry:/tmp$ python w.py 
上文的錯誤操作:
[tensor([3.1000, 4.1000, 5.1000], dtype=torch.float64), tensor([3.2000, 4.2000, 5.2000], dtype=torch.float64), tensor([3.3000, 4.3000, 5.3000], dtype=torch.float64)]
------------------------------
[tensor([1.1000, 2.1000], dtype=torch.float64), tensor([1.2000, 2.2000], dtype=torch.float64), tensor([1.3000, 2.3000], dtype=torch.float64)]
------------------------------
again
[tensor([3.1000, 5.1000, 1.1000], dtype=torch.float64), tensor([3.2000, 5.2000, 1.2000], dtype=torch.float64), tensor([3.3000, 5.3000, 1.3000], dtype=torch.float64)]
------------------------------
[tensor([2.1000, 4.1000], dtype=torch.float64), tensor([2.2000, 4.2000], dtype=torch.float64), tensor([2.3000, 4.3000], dtype=torch.float64)]


------------------------------

修正后的正確操作: tensor([[
2.1000, 2.2000, 2.3000], [1.1000, 1.2000, 1.3000], [3.1000, 3.2000, 3.3000]], dtype=torch.float64) ------------------------------ tensor([[4.1000, 4.2000, 4.3000], [5.1000, 5.2000, 5.3000]], dtype=torch.float64) ------------------------------ again tensor([[5.1000, 5.2000, 5.3000], [4.1000, 4.2000, 4.3000], [3.1000, 3.2000, 3.3000]], dtype=torch.float64) ------------------------------ tensor([[2.1000, 2.2000, 2.3000], [1.1000, 1.2000, 1.3000]], dtype=torch.float64) ------------------------------

 

可以看出  傳入到   Dataset  中的對象必須是  torch  類型的 tensor  類型, 如果傳入的是list則會得出錯誤結果。

 

 

 

-----------------------------------------------------------------------------------------------------

 

 

補充:

之所以發現上面的這個錯誤,是因為發現了下面的代碼:

import numpy as np
from torchvision.datasets import mnist # 導入 pytorch 內置的 mnist 數據
from torch.utils.data import DataLoader
#from torch.utils.data import Dataset


def data_tf(x):
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5 # 數據預處理,標准化
    x = x.reshape((-1,)) # 拉平
    x = torch.from_numpy(x)
    return x


#Dataset
# 重新載入數據集,申明定義的數據變換
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)


train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

 

從上面的   data_tf  函數中我們發現,  Dataset對象返回的是   torch 的  tensor 對象。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM