Pytorch dataset自定義【直播】2019 年縣域農業大腦AI挑戰賽---數據准備(二),Dataset定義


在我的torchvision庫里介紹的博文(https://www.cnblogs.com/yjphhw/p/9773333.html)里說了對pytorch的dataset的定義方式。

本文相當於實現一個自定義的數據集,而這正是我們在做自己工程所需要的,我們總是用自己的數據嘛。

繼承 from torch.utils.data import Dataset 類

然后實現 __len__(self) ,和 __getitem__(self,idx) 兩個方法。以及數據增強也可以寫入,數據增強想了想還是放到了Dataset里,

習慣上可能與常用的不同,但是覺得由於每種數據都有自己的增強方法所以,增強方法可以和數據集綁定到一起的。

接上一節我們通過切割,獲取了2217個圖像切片。

這就是我的FarmDataset

from torch.utils.data import Dataset, DataLoader
from PIL import Image,ImageEnhance
from osgeo import gdal
from torchvision import transforms
import glob
import torch as tc 
import numpy as np


class FarmDataset(Dataset):
    def __init__(self,istrain=True,isaug=True):
        self.istrain=istrain
        self.trainxformat='./data/train/data1500/*.bmp'
        self.trainyformat='./data/train/label1500/*.bmp'
        self.testxformat='./data/test/*.png'
        self.fns=glob.glob(self.trainxformat) if istrain else glob.glob(self.testxformat)
        self.length=len(self.fns)
        self.transforms=transforms
        self.isaug=isaug
        
    def __len__(self):
        #total length is 2217 
        return self.length
    def __getitem__(self,idx):
        if self.istrain:
            
            imgxname=self.fns[idx]
            sampleimg = Image.open(imgxname)
            imgyname=imgxname.replace('data1500','label1500')
            targetimg = Image.open(imgyname).convert('L')
            #sampleimg.save('original.bmp')
            
            #data augmentation
            if self.isaug:
                sampleimg,targetimg=self.imgtrans(sampleimg,targetimg)
            
            #check the result of dataautmentation
            #sampleimg.save('sampletmp.bmp')
            #targetimg.save('targettmp.bmp')
            
            sampleimg=transforms.ToTensor()(sampleimg) 
            #targetimg=transforms.ToTensor()(targetimg).squeeze(0).long() 
            targetimg=np.array(targetimg)
            targetimg=tc.from_numpy(targetimg).long()         #to tensor
            #print(sampleimg.shape,targetimg.shape)
            return sampleimg,targetimg
        else:
            return gdal.Open(self.fns[idx])
    def imgtrans(self,x,y,outsize=1024):
        '''input is a PIL image 
           image dataaugumentation
           return also aPIL image。
        '''
        #rotate should consider y
        degree=np.random.randint(360)
        x=x.rotate(degree,resample=Image.NEAREST,fillcolor=0)
        y=y.rotate(degree,resample=Image.NEAREST,fillcolor=0)  #here should be carefull, in case of label damage
        
        #random do the input image augmentation
        if np.random.random()>0.5:
            #sharpness 
            factor=0.5+np.random.random()
            enhancer=ImageEnhance.Sharpness(x)
            x=enhancer.enhance(factor)
        if np.random.random()>0.5:
            #color augument
            factor=0.5+np.random.random()
            enhancer=ImageEnhance.Color(x)
            x=enhancer.enhance(factor)
        if np.random.random()>0.5:
            #contrast augument
            factor=0.5+np.random.random()
            enhancer=ImageEnhance.Contrast(x)
            x=enhancer.enhance(factor)
        if np.random.random()>0.5:
            #brightness
            factor=0.5+np.random.random()
            enhancer=ImageEnhance.Brightness(x)
            x=enhancer.enhance(factor)
        
        #img flip
        transtypes=[Image.FLIP_LEFT_RIGHT,Image.FLIP_TOP_BOTTOM,
                Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270]
        transtype=transtypes[np.random.randint(len(transtypes))]
        x = x.transpose(transtype)
        y = y.transpose(transtype)
        
        #img resize between 0.8-1.2
        w,h=x.size
        factor=1+np.random.normal()/5
        if factor>1.2: factor=1.2
        if factor<0.8: factor=0.8
        #print(factor,x.size)
        x=x.resize((int(w*factor),int(h*factor)),Image.NEAREST)
        y=y.resize((int(w*factor),int(h*factor)),Image.NEAREST)
        
        #random crop
        w,h=x.size
        stx=np.random.randint(w-outsize)
        sty=np.random.randint(h-outsize)
        #print((stx,sty,outsize,outsize))
        x=x.crop((stx,sty,stx+outsize,sty+outsize)) #stx,sty,width,height
        y=y.crop((stx,sty,stx+outsize,sty+outsize))
        #print(x.size,y.size)
        return x,y   #return outsized pil image
    

if __name__=='__main__':
    d=FarmDataset(istrain=True)
    x,y=d[2216]
    print(x.shape)
    print(y.shape)

  

 

  輸入的是個1500x1500的圖像,輸出的是增強后的1024x1024后的圖像。

  其實對於分割問題來看,以后這個就可以作為一個模板,修改修改就可以換到另一個數據集中。

放幾張圖片:

原始圖像:

進行數據增強后可以得到的一系列:

經過check 發現沒有的問題通過測試。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM