PyTorch SubsetRandomSampler: usage and explanation


Official docs: https://pytorch.org/docs/stable/data.html?highlight=subsetrandomsampler#torch.utils.data.SubsetRandomSampler

Recommended references: https://www.sohu.com/a/291959747_197042

https://www.jianshu.com/p/a32ae0294223

https://www.cnblogs.com/marsggbo/p/10496696.html

 

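How it works:

DataLoader first draws indices with the sampler and then groups them into batches. For example, with 10 samples and a SubsetRandomSampler built over the indices 0 to 7, only those 8 samples are drawn, in random order and each exactly once per pass, and the resulting sequence is then cut into batches of batch_size.

Practical example: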
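The sample-then-batch idea can be sketched in plain Python. This is only a simplified model of the behaviour described above, not PyTorch's actual implementation; the helper names subset_random_order and make_batches are made up for illustration.

```python
import random

def subset_random_order(indices, rng):
    """Return the given indices in a random order, each exactly once."""
    shuffled = list(indices)
    rng.shuffle(shuffled)
    return shuffled

def make_batches(order, batch_size):
    """Cut an index sequence into consecutive batches; the last may be smaller."""
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

samples = [f"sample_{i}" for i in range(10)]  # toy dataset with 10 items
subset = list(range(8))                       # only indices 0..7 may be drawn
rng = random.Random(0)                        # seeded so runs are repeatable

order = subset_random_order(subset, rng)
index_batches = make_batches(order, batch_size=3)
data_batches = [[samples[i] for i in batch] for batch in index_batches]

print([len(b) for b in index_batches])  # [3, 3, 2]
```

Samples 8 and 9 never show up in data_batches, because their indices were never handed to the sampler.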

from torch.utils.data import DataLoader
from torch.utils.data import sampler

# CriteoDataset is a user-defined Dataset class
train_data = CriteoDataset('./data', train=True)

# first 80% of the indices for training, the remaining 20% for validation
split_num = int(len(train_data) * 0.8)
index_list = list(range(len(train_data)))
train_idx, valid_idx = index_list[:split_num], index_list[split_num:]

tr_sampler = sampler.SubsetRandomSampler(train_idx)
val_sampler = sampler.SubsetRandomSampler(valid_idx)

# both loaders read from train_data; the samplers keep the two subsets disjoint
loader_train = DataLoader(train_data, batch_size=100, sampler=tr_sampler)
loader_val = DataLoader(train_data, batch_size=100, sampler=val_sampler)
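CriteoDataset is not shown in these notes, so the snippet cannot be run as written. Below is a minimal runnable sketch of the same 80/20 split, assuming a small TensorDataset as a stand-in; the tensor shapes, the batch size of 16, and the trick of using the row index as the label are all illustrative choices, not part of the original code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Stand-in for the user-defined CriteoDataset: 50 rows, 4 features each.
# The label of each row is its own index, which makes drawn rows easy to trace.
features = torch.arange(200, dtype=torch.float32).reshape(50, 4)
labels = torch.arange(50)
dataset = TensorDataset(features, labels)

# Same split as above: first 80% for training, the rest for validation.
split_num = int(len(dataset) * 0.8)
index_list = list(range(len(dataset)))
train_idx, valid_idx = index_list[:split_num], index_list[split_num:]

loader_train = DataLoader(dataset, batch_size=16,
                          sampler=SubsetRandomSampler(train_idx))
loader_val = DataLoader(dataset, batch_size=16,
                        sampler=SubsetRandomSampler(valid_idx))

# Collect the labels (= row indices) each loader produced over one pass.
seen_train = sorted(int(y) for _, yb in loader_train for y in yb)
seen_val = sorted(int(y) for _, yb in loader_val for y in yb)

print(len(seen_train), len(seen_val))  # 40 10
```

Both loaders wrap the same dataset object; the two index lists alone decide which rows each loader can see, so the training and validation rows do not overlap.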

 

