pytorch 加載數據集


pytorch初學者,想加載自己的數據,了解了一下數據類型、維度等信息,方便以后加載其他數據。

1 torchvision.transforms實現數據預處理

transforms.Totensor()操作必須要有,將數據轉為張量格式。

2 torch.utils.data.Dataset實現數據讀取

要使用自己的數據集,需要構建Dataset子類,定義子類為MyDataset,在MyDataset的init函數中定義path_dict變量,來獲取不同類型的數據的路徑。

定義子類MyDataset時,必須要重載兩個函數 getitem 和 len,

__getitem__:實現數據集的下標索引,返回對應的數據及標簽;

__len__:返回數據集的大小。

 

設加載的數據集大小為L;

定義MyDataset實例:my_datasets = MyDataset(data_dir, transform = data_transform) 。

my_datasets 由L個tuple組成,len(my_datasets) = L;

每個tuple長度為2:0:tensor 樣本(Channel,Height,Width)

                               1:int 標簽 

 

   

3 torch.utils.data.DataLoader實現數據集加載

torch.utils.data.DataLoader()合成數據並提供迭代訪問,由兩部分組成:

—dataset(Dataset):輸入要加載的數據,就是上面的my_datasets;

—batch_size,shuffle,sampler,batch_sampler,num_workers,collate_fn, drop_last,timeout,worker_init_fn等參數。

其中:batch_size:批尺寸,默認為1; 

      shuffle:是否在每個epoch開始隨機打亂數據,默認為False;

 

設data_loader長度為 l ;

加載數據:data_loader = DataLoader(my_datasets, batch_size = BATCH_SIZE,  shuffle = True) 

data_loader 由 l 個 tuple組成,l = len(data_loader) = len(my_datasets) / batch_size;

迭代訪問:

 

e 長度為2:0:int  step 表示第幾個batch

                   1:list(長度為2)表示一個batch包含的所有樣本和標簽   

                                  0:tensor  樣本(Batch_size,Channel,Height,Width)

                                  1:tensor  標簽    Batch_size     

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM