Pytorch划分數據集的方法


之前用過sklearn提供的划分數據集的函數,覺得超級方便。但是在使用TensorFlow和Pytorch的時候一直找不到類似的功能,之前搜索的關鍵字都是“pytorch split dataset”之類的,但是搜出來還是沒有我想要的。結果今天見鬼了突然看見了這么一個函數torch.utils.data.Subset。我的天,為什么超級開心hhhh。終於不用每次都手動划分數據集了。

torch.utils.data

Pytorch提供的對數據集進行操作的函數詳見:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler

torch的這個文件包含了一些關於數據集處理的類:

  • class torch.utils.data.Dataset: 一個抽象類, 所有其他類的數據集類都應該是它的子類。而且其子類必須重載兩個重要的函數:len(提供數據集的大小)、getitem(支持整數索引)。
  • class torch.utils.data.TensorDataset: 封裝成tensor的數據集,每一個樣本都通過索引張量來獲得。
  • class torch.utils.data.ConcatDataset: 連接不同的數據集以構成更大的新數據集。
  • class torch.utils.data.Subset(dataset, indices): 獲取指定一個索引序列對應的子數據集。
  • class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 數據加載器。組合了一個數據集和采樣器,並提供關於數據的迭代器。
  • torch.utils.data.random_split(dataset, lengths): 按照給定的長度將數據集划分成沒有重疊的新數據集組合。
  • class torch.utils.data.Sampler(data_source):所有采樣的器的基類。每個采樣器子類都需要提供 iter 方-法以方便迭代器進行索引 和一個 len方法 以方便返回迭代器的長度。
  • class torch.utils.data.SequentialSampler(data_source):順序采樣樣本,始終按照同一個順序。
  • class torch.utils.data.RandomSampler(data_source):無放回地隨機采樣樣本元素。
  • class torch.utils.data.SubsetRandomSampler(indices):無放回地按照給定的索引列表采樣樣本元素。
  • class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照給定的概率來采樣樣本。
  • class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一個batch中封裝一個其他的采樣器。
  • class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采樣器可以約束數據加載進數據集的子集。

示例

下面Pytorch提供的划分數據集的方法以示例的方式給出:

SubsetRandomSampler

...

dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:   
    for batch_index, (faces, labels) in enumerate(train_loader):
        # ...

random_split

...

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

參考:




微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com

2019-3-8




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM