上文介紹了數據讀取、數據轉換、批量處理等等。了解到在PyTorch中,數據加載主要有兩種方式:
- 1. 自定義的數據集對象。數據集對象被抽象為
Dataset
類,實現自定義的數據集需要繼承Dataset。且須實現__len__()和__getitem__()兩個方法。 - 2. 利用torchvision包。torchvision已經預先實現了常用的Dataset,包括前面使用過的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等數據集,可通過諸如
torchvision.datasets.CIFAR10
來調用。這里介紹ImageFolder
,其也繼承自Dataset。ImageFolder
假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片,文件夾名為類名,其構造函數如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四個參數:
-
root
:在root指定的路徑下尋找圖片transform
:對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回對象target_transform
:對label的轉換loader
:給定路徑后如何讀取圖片,默認讀取為RGB格式的PIL Image對象
label是按照文件夾名順序排序后存成字典,即{類名:類序號(從0開始)},一般來說最好直接將文件夾命名為從0開始的數字,這樣會和ImageFolder實際的label一致,如果不是這種命名規 范,建議看看self.class_to_idx
屬性以了解label和文件夾名的映射關系。
下面我們進一步理解數據讀取的內容:
1. 查看dataset實例都有哪些方法與成員:
from torchvision.datasets import ImageFolder dataset = ImageFolder('data/ants&bee_2/', transform=transform)
打印一下dataset類的成員與方法:
print(dir(dataset))
['__add__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'class_to_idx', 'classes', 'imgs', 'loader', 'root', 'target_transform', 'transform']
dataset.__len__() : 數據集的數目
dataset.__getitem__(idx) : 輸入索引,返回對應的圖片與標簽
dataste.class_to_idx : 字典,類與標簽 eg:{‘ants’:0, ‘bees’: 1}
dataset.classes: 列表,返回類別 eg:[ 'ants', 'bees' ]
dataset.imgs : 列表,返回所有圖片的路徑和對應的label
特別的對於dataset,可以根據dataset.__getitem__(idx)來返回第idx張圖與標簽,還可以直接進行索引:
dataset[0][1] # 第一維是第幾張圖,第二維為1返回label
dataset[0][0] # 為0返回圖片數據
還可以循環迭代:
for img, label in dataset: print(img.size(), label)
所以無論是自定義的dataset,或是 ImageFolder得到的dataset,因其都繼承自utils.data.Dataset, 故以上方法兩種方法都有。
2. Dataloader使用
如果只是每次讀取一張圖,那么上面的操作已經足夠了,但是為了批量操作、打散數據、多進程處理、定制batch,那么我們還需要更高級的類:
DataLoader定義如下:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset:加載的數據集(Dataset對象)
- batch_size:batch size
- shuffle::是否將數據打亂
- sampler: 樣本抽樣。定義從數據集中提取樣本的策略。如果指定,則忽略shuffle參數。
- batch_sampler(sampler,可選) - 和sampler一樣,但一次返回一批索引。與batch_size,shuffle,sampler和drop_last相互排斥。
- num_workers:使用多進程加載的進程數,0代表不使用多進程
- collate_fn: 如何將多個樣本數據拼接成一個batch,一般使用默認的拼接方式即可
- pin_memory:是否將數據保存在pin memory區,pin memory中的數據轉到GPU會快一些,默認為false
- drop_last:dataset中的數據個數可能不是batch_size的整數倍,drop_last為True會將多出來不足一個batch的數據丟棄,默認為false
from torch.utils.data import DataLoader dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader) # 迭代器
imgs, labels = next(dataiter) imgs.size() # batch_size, channel, height, weight # torch.Size([3, 3, 224, 224])
下面主要介紹collate_fn和sampler
的用法:
1)collate_fn
在數據處理中,有時會出現某個樣本無法讀取等問題,比如某張圖片損壞。這時在__getitem__
函數中將出現異常,此時最好的解決方案即是將出錯的樣本剔除。如果實在是遇到這種情況無法處理,則可以返回None對象,然后在Dataloader
中實現自定義的collate_fn
,將空對象過濾掉。但要注意,在這種情況下dataloader返回的batch數目會少於batch_size。
eg:
class NewDogCat(DogCat): # 繼承前面實現的DogCat數據集
def __getitem__(self, index): try: # 調用父類的獲取函數,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index) except: return None, None from torch.utils.data.dataloader import default_collate # 導入默認的拼接方式
def my_collate_fn(batch): ''' batch中每個元素形如(data, label) '''
# 過濾為None的數據
batch = list(filter(lambda x:x[0] is not None, batch)) return default_collate(batch) # 用默認方式拼接過濾后的batch數據
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform) dataset[5] # (None, None)
第5張圖壞掉了所以返回None,下面查看對於批量讀取怎么處理:
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1) # 批量為2 for batch_datas, batch_labels in dataloader: print(batch_datas.size(),batch_labels.size())
可以看到第三個批量只有1張圖,因為第5張圖壞掉了,所以第三個批量只有第六張圖。第五個批量也只有1張圖是因為數據集總共只有9張圖(含壞圖)。如果設置drop_last為true,那么第五個批量就被丟棄了。對於諸如樣本損壞或數據集加載異常等情況,還可以通過其它方式解決。例如但凡遇到異常情況,就隨機取一張圖片代替:
class NewDogCat(DogCat): def __getitem__(self, index): try: return super(NewDogCat, self).__getitem__(index) except: new_index = random.randint(0, len(self)-1) return self[new_index]
相比較丟棄異常圖片而言,這種做法會更好一些,因為它能保證每個batch的數目仍是batch_size。但在大多數情況下,最好的方式還是對數據進行徹底清洗。
2)sampler
sampler
模塊用來對數據進行采樣。常用的有隨機采樣器:RandomSampler
,當dataloader的shuffle
參數為True時,系統會自動調用這個采樣器,實現打亂數據。默認的是采用SequentialSampler
,它會按順序一個一個進行采樣。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler
,它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重采樣。
class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)
構建WeightedRandomSampler
時需提供兩個參數:每個樣本的權重weights
、共選取的樣本總數num_samples
,以及一個可選參數replacement
。權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。replacement
用於指定是否可以重復選取某一個樣本,默認為True,即允許在一個epoch中重復采樣某一個數據。如果設為False,則當某一類的樣本被全部選取完,但其樣本數目仍未達到num_samples時,sampler將不會再從該類中選擇數據,此時可能導致weights
參數失效。下面舉例說明。
dataset = DogCat('data/dogcat/', transforms=transform) # 狗的圖片被取出的概率是貓的概率的兩倍 # 兩類圖片被取出的概率與weights的絕對大小無關,只和比值有關
weights = [2 if label == 1 else 1 for data, label in dataset] weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import WeightedRandomSampler sampler = WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader = DataLoader(dataset, batch_size=3, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())
[1, 0, 1] [1, 0, 1] [1, 1, 0]
可見貓狗樣本比例約為1:2,另外一共只有8個樣本,但是卻返回了9個,說明肯定有被重復返回的,這就是replacement參數的作用,下面將replacement設為False試試:
sampler = WeightedRandomSampler(weights, 8, replacement=False) dataloader = DataLoader(dataset, batch_size=4, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())
在這種情況下,num_samples等於dataset的樣本總數,為了不重復選取,sampler會將每個樣本都返回,這樣就失去weight參數的意義了。
從上面的例子可見sampler在樣本采樣中的作用:如果指定了sampler,shuffle將不再生效,並且sampler.num_samples會覆蓋dataset的實際大小,即一個epoch返回的圖片總數取決於sampler.num_samples。
部分轉載自:pytorch-book-master