在我的torchvision庫里介紹的博文(https://www.cnblogs.com/yjphhw/p/9773333.html)里說了對pytorch的dataset的定義方式。
本文相當於實現一個自定義的數據集,而這正是我們在做自己工程所需要的,我們總是用自己的數據嘛。
繼承 from torch.utils.data import Dataset 類
然后實現 __len__(self) ,和 __getitem__(self,idx) 兩個方法。以及數據增強也可以寫入,數據增強想了想還是放到了Dataset里,
習慣上可能與常用的不同,但是覺得由於每種數據都有自己的增強方法所以,增強方法可以和數據集綁定到一起的。
接上一節我們通過切割,獲取了2217個圖像切片。
這就是我的FarmDataset
from torch.utils.data import Dataset, DataLoader from PIL import Image,ImageEnhance from osgeo import gdal from torchvision import transforms import glob import torch as tc import numpy as np class FarmDataset(Dataset): def __init__(self,istrain=True,isaug=True): self.istrain=istrain self.trainxformat='./data/train/data1500/*.bmp' self.trainyformat='./data/train/label1500/*.bmp' self.testxformat='./data/test/*.png' self.fns=glob.glob(self.trainxformat) if istrain else glob.glob(self.testxformat) self.length=len(self.fns) self.transforms=transforms self.isaug=isaug def __len__(self): #total length is 2217 return self.length def __getitem__(self,idx): if self.istrain: imgxname=self.fns[idx] sampleimg = Image.open(imgxname) imgyname=imgxname.replace('data1500','label1500') targetimg = Image.open(imgyname).convert('L') #sampleimg.save('original.bmp') #data augmentation if self.isaug: sampleimg,targetimg=self.imgtrans(sampleimg,targetimg) #check the result of dataautmentation #sampleimg.save('sampletmp.bmp') #targetimg.save('targettmp.bmp') sampleimg=transforms.ToTensor()(sampleimg) #targetimg=transforms.ToTensor()(targetimg).squeeze(0).long() targetimg=np.array(targetimg) targetimg=tc.from_numpy(targetimg).long() #to tensor #print(sampleimg.shape,targetimg.shape) return sampleimg,targetimg else: return gdal.Open(self.fns[idx]) def imgtrans(self,x,y,outsize=1024): '''input is a PIL image image dataaugumentation return also aPIL image。 ''' #rotate should consider y degree=np.random.randint(360) x=x.rotate(degree,resample=Image.NEAREST,fillcolor=0) y=y.rotate(degree,resample=Image.NEAREST,fillcolor=0) #here should be carefull, in case of label damage #random do the input image augmentation if np.random.random()>0.5: #sharpness factor=0.5+np.random.random() enhancer=ImageEnhance.Sharpness(x) x=enhancer.enhance(factor) if np.random.random()>0.5: #color augument factor=0.5+np.random.random() enhancer=ImageEnhance.Color(x) x=enhancer.enhance(factor) if np.random.random()>0.5: #contrast augument factor=0.5+np.random.random() enhancer=ImageEnhance.Contrast(x) x=enhancer.enhance(factor) if np.random.random()>0.5: #brightness factor=0.5+np.random.random() enhancer=ImageEnhance.Brightness(x) x=enhancer.enhance(factor) #img flip transtypes=[Image.FLIP_LEFT_RIGHT,Image.FLIP_TOP_BOTTOM, Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270] transtype=transtypes[np.random.randint(len(transtypes))] x = x.transpose(transtype) y = y.transpose(transtype) #img resize between 0.8-1.2 w,h=x.size factor=1+np.random.normal()/5 if factor>1.2: factor=1.2 if factor<0.8: factor=0.8 #print(factor,x.size) x=x.resize((int(w*factor),int(h*factor)),Image.NEAREST) y=y.resize((int(w*factor),int(h*factor)),Image.NEAREST) #random crop w,h=x.size stx=np.random.randint(w-outsize) sty=np.random.randint(h-outsize) #print((stx,sty,outsize,outsize)) x=x.crop((stx,sty,stx+outsize,sty+outsize)) #stx,sty,width,height y=y.crop((stx,sty,stx+outsize,sty+outsize)) #print(x.size,y.size) return x,y #return outsized pil image if __name__=='__main__': d=FarmDataset(istrain=True) x,y=d[2216] print(x.shape) print(y.shape)
輸入的是個1500x1500的圖像,輸出的是增強后的1024x1024后的圖像。
其實對於分割問題來看,以后這個就可以作為一個模板,修改修改就可以換到另一個數據集中。
放幾張圖片:
原始圖像:
進行數據增強后可以得到的一系列:
經過check 發現沒有的問題通過測試。