Pytorch tutorial 之Datar Loading and Processing (2)


上文介紹了數據讀取、數據轉換、批量處理等等。了解到在PyTorch中,數據加載主要有兩種方式:

  • 1. 自定義的數據集對象。數據集對象被抽象為Dataset類,實現自定義的數據集需要繼承Dataset。且須實現__len__()和__getitem__()兩個方法。
  • 2. 利用torchvision包torchvision已經預先實現了常用的Dataset,包括前面使用過的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等數據集,可通過諸如torchvision.datasets.CIFAR10來調用。這里介紹ImageFolder,其也繼承自DatasetImageFolder假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片,文件夾名為類名,其構造函數如下:
 ImageFolder(root, transform=None, target_transform=None, loader=default_loader) 

         它主要有四個參數:

    • root:在root指定的路徑下尋找圖片
    • transform:對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回對象
    • target_transform:對label的轉換
    • loader:給定路徑后如何讀取圖片,默認讀取為RGB格式的PIL Image對象

         label是按照文件夾名順序排序后存成字典,即{類名:類序號(從0開始)},一般來說最好直接將文件夾命名為從0開始的數字,這樣會和ImageFolder實際的label一致,如果不是這種命名規 范,建議看看self.class_to_idx屬性以了解label和文件夾名的映射關系。

 

下面我們進一步理解數據讀取的內容:

1.  查看dataset實例都有哪些方法與成員:

from torchvision.datasets import ImageFolder dataset = ImageFolder('data/ants&bee_2/', transform=transform)

打印一下dataset類的成員與方法:

print(dir(dataset))

['__add__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'class_to_idx', 'classes', 'imgs', 'loader', 'root', 'target_transform', 'transform']

dataset.__len__()  : 數據集的數目

dataset.__getitem__(idx) : 輸入索引,返回對應的圖片與標簽

dataste.class_to_idx : 字典,類與標簽  eg:{‘ants’:0, ‘bees’: 1}

dataset.classes: 列表,返回類別  eg:[ 'ants', 'bees' ]

dataset.imgs : 列表,返回所有圖片的路徑和對應的label

特別的對於dataset,可以根據dataset.__getitem__(idx)來返回第idx張圖與標簽,還可以直接進行索引:

dataset[0][1]    # 第一維是第幾張圖,第二維為1返回label
dataset[0][0]    # 為0返回圖片數據

還可以循環迭代:

for img, label in dataset: print(img.size(), label)

所以無論是自定義的dataset,或是 ImageFolder得到的dataset,因其都繼承自utils.data.Dataset, 故以上方法兩種方法都有

 

2.  Dataloader使用

如果只是每次讀取一張圖,那么上面的操作已經足夠了,但是為了批量操作、打散數據、多進程處理、定制batch,那么我們還需要更高級的類:

DataLoader定義如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

  • dataset:加載的數據集(Dataset對象)
  • batch_size:batch size
  • shuffle::是否將數據打亂
  • sampler: 樣本抽樣。定義從數據集中提取樣本的策略。如果指定,則忽略shuffle參數。
  • batch_sampler(sampler,可選) - 和sampler一樣,但一次返回一批索引。與batch_size,shuffle,sampler和drop_last相互排斥。
  • num_workers:使用多進程加載的進程數,0代表不使用多進程
  • collate_fn: 如何將多個樣本數據拼接成一個batch,一般使用默認的拼接方式即可
  • pin_memory:是否將數據保存在pin memory區,pin memory中的數據轉到GPU會快一些,默認為false
  • drop_last:dataset中的數據個數可能不是batch_size的整數倍,drop_last為True會將多出來不足一個batch的數據丟棄,默認為false
from torch.utils.data import DataLoader dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter
= iter(dataloader) # 迭代器
imgs, labels
= next(dataiter) imgs.size() # batch_size, channel, height, weight #
torch.Size([3, 3, 224, 224])

下面主要介紹collate_fn和sampler的用法:

1)collate_fn

在數據處理中,有時會出現某個樣本無法讀取等問題,比如某張圖片損壞。這時在__getitem__函數中將出現異常,此時最好的解決方案即是將出錯的樣本剔除。如果實在是遇到這種情況無法處理,則可以返回None對象,然后在Dataloader中實現自定義的collate_fn,將空對象過濾掉。但要注意,在這種情況下dataloader返回的batch數目會少於batch_size。

eg:

class NewDogCat(DogCat): # 繼承前面實現的DogCat數據集
    def __getitem__(self, index): try: # 調用父類的獲取函數,即 DogCat.__getitem__(self, index)
            return super(NewDogCat,self).__getitem__(index) except: return None, None from torch.utils.data.dataloader import default_collate # 導入默認的拼接方式
def my_collate_fn(batch): ''' batch中每個元素形如(data, label) '''
    # 過濾為None的數據
    batch = list(filter(lambda x:x[0] is not None, batch)) return default_collate(batch) # 用默認方式拼接過濾后的batch數據
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform) dataset[5]    # (None, None)

第5張圖壞掉了所以返回None,下面查看對於批量讀取怎么處理:

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1) # 批量為2 for batch_datas, batch_labels in dataloader: print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

可以看到第三個批量只有1張圖,因為第5張圖壞掉了,所以第三個批量只有第六張圖。第五個批量也只有1張圖是因為數據集總共只有9張圖(含壞圖)。如果設置drop_last為true,那么第五個批量就被丟棄了對於諸如樣本損壞或數據集加載異常等情況,還可以通過其它方式解決。例如但凡遇到異常情況,就隨機取一張圖片代替:

class NewDogCat(DogCat): def __getitem__(self, index): try: return super(NewDogCat, self).__getitem__(index) except: new_index = random.randint(0, len(self)-1) return self[new_index]

 相比較丟棄異常圖片而言,這種做法會更好一些,因為它能保證每個batch的數目仍是batch_size。但在大多數情況下,最好的方式還是對數據進行徹底清洗。

 2)sampler

        sampler模塊用來對數據進行采樣。常用的有隨機采樣器:RandomSampler,當dataloader的shuffle參數為True時,系統會自動調用這個采樣器,實現打亂數據。默認的是采用SequentialSampler,它會按順序一個一個進行采樣。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler,它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重采樣。

       class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)

       構建WeightedRandomSampler時需提供兩個參數:每個樣本的權重weights、共選取的樣本總數num_samples,以及一個可選參數replacement。權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。replacement用於指定是否可以重復選取某一個樣本,默認為True,即允許在一個epoch中重復采樣某一個數據。如果設為False,則當某一類的樣本被全部選取完,但其樣本數目仍未達到num_samples時,sampler將不會再從該類中選擇數據,此時可能導致weights參數失效。下面舉例說明。

dataset = DogCat('data/dogcat/', transforms=transform) # 狗的圖片被取出的概率是貓的概率的兩倍 # 兩類圖片被取出的概率與weights的絕對大小無關,只和比值有關
weights = [2 if label == 1 else 1 for data, label in dataset] weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import WeightedRandomSampler sampler = WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader = DataLoader(dataset, batch_size=3, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())
[1, 0, 1]
[1, 0, 1]
[1, 1, 0]

可見貓狗樣本比例約為1:2,另外一共只有8個樣本,但是卻返回了9個,說明肯定有被重復返回的,這就是replacement參數的作用,下面將replacement設為False試試:

sampler = WeightedRandomSampler(weights, 8, replacement=False) dataloader = DataLoader(dataset, batch_size=4, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())
[1, 0, 1, 0]
[1, 1, 0, 0]

在這種情況下,num_samples等於dataset的樣本總數,為了不重復選取,sampler會將每個樣本都返回,這樣就失去weight參數的意義了。

從上面的例子可見sampler在樣本采樣中的作用:如果指定了sampler,shuffle將不再生效,並且sampler.num_samples會覆蓋dataset的實際大小,即一個epoch返回的圖片總數取決於sampler.num_samples。

 

部分轉載自:pytorch-book-master


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM