【Source Code Walkthrough】CycleGAN (3): Data Loading


Source code: https://github.com/aitorzip/PyTorch-CycleGAN

Reading the data is fairly simple. CycleGAN has no need for paired data: the two datasets from the two domains are simply stored in separate folders A and B, and all we need is a Dataset interface to read them.

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC), 
                transforms.RandomCrop(opt.size), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True), 
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

In the code above, the buffers are defined first (more on these later), then the image transforms: resize to 1.12× the target size, random-crop back to opt.size, random horizontal flip, convert to tensor, and normalize each channel to the range [-1, 1]. Wrapping the custom ImageDataset (a Dataset subclass) in a DataLoader then lets us read batches. Below is the ImageDataset code:

import glob
import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        # Collect every file under root/<mode>/A and root/<mode>/B, sorted by name
        self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))

    def __getitem__(self, index):
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

        if self.unaligned:
            # unpaired: draw a random image from domain B
            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
        else:
            # "aligned": read B in the same order, wrapping around if B is shorter
            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
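As a side note, the glob pattern used in __init__ can be checked with a tiny standalone sketch. The temporary directory layout and file names below are invented for the demo; only the pattern itself comes from the repo:

```python
import glob
import os
import tempfile

# Hypothetical layout: root/train/A and root/train/B, one folder per domain.
root = tempfile.mkdtemp()
for domain in ('A', 'B'):
    os.makedirs(os.path.join(root, 'train', domain))
for name in ('2.jpg', '1.png'):
    open(os.path.join(root, 'train', 'A', name), 'w').close()

# Same pattern as ImageDataset.__init__: every file with an extension, sorted.
files_A = sorted(glob.glob(os.path.join(root, '%s/A' % 'train') + '/*.*'))
print([os.path.basename(p) for p in files_A])  # ['1.png', '2.jpg']
```

Note that sorted() orders by file name lexicographically, which is why '1.png' comes before '2.jpg'.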

The class implements the standard __init__, __getitem__, and __len__ interfaces. (I am still not entirely sure what the sorting and alignment are for: with unaligned=False the two domains are read in matching order, while unaligned=True samples B at random.) Finally, regarding the buffer, the CycleGAN paper puts it this way: "Second, to reduce model oscillation [15], we follow Shrivastava et al.'s strategy [46] and update the discriminators using a history of generated images rather than the ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images."
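To make the aligned case concrete, here is a hypothetical sketch (file names invented) of how the index % len trick in __getitem__ lets __len__ = max(...) cycle through the shorter file list:

```python
# Hypothetical file lists: domain A has 5 images, domain B only 3.
files_A = ['A_%d.jpg' % i for i in range(5)]
files_B = ['B_%d.jpg' % i for i in range(3)]

length = max(len(files_A), len(files_B))  # what __len__ returns: 5
pairs = [(files_A[i % len(files_A)], files_B[i % len(files_B)])
         for i in range(length)]
print(pairs[3])  # ('A_3.jpg', 'B_0.jpg') -- B wraps around
```

So no image from the larger domain is skipped in an epoch; the smaller domain is simply reused.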

In other words, for training stability, the discriminators are updated with historically generated fake samples rather than only those just produced by the current generators; the rationale is in the referenced paper. Let's look at the code:

import random

import torch
from torch.autograd import Variable


class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []  # history of up to max_size generated images

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)  # restore the batch dimension
            if len(self.data) < self.max_size:
                # buffer not full: store the new image and also return it
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    # with probability 0.5, return a stored historical image
                    # and replace it with the new one
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    # otherwise return the new image unchanged
                    to_return.append(element)
        return Variable(torch.cat(to_return))

This defines a buffer object with an internal data list whose capacity defaults to 50. As I read it, the flow is: while the list is not yet full, every call returns the freshly generated fake images (storing them as well); once the list is full, each image is decided by a coin flip: 1. draw a random stored image, return it, and put the current image in its place; or 2. return the current image unchanged.
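That flow can be verified with a tensor-free sketch. The class below is my own simplification, not from the repo: plain list elements stand in for image tensors, but the branching mirrors push_and_pop exactly:

```python
import random

# Simplified, tensor-free sketch of the ReplayBuffer logic above.
class SimpleReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        to_return = []
        for element in batch:
            if len(self.data) < self.max_size:
                # buffer not full yet: store and return the fresh sample
                self.data.append(element)
                to_return.append(element)
            elif random.random() > 0.5:
                # 50%: return a random stored (historical) sample,
                # replacing it with the fresh one
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i])
                self.data[i] = element
            else:
                # 50%: return the fresh sample unchanged
                to_return.append(element)
        return to_return
```

A quick exercise: with max_size=2, the first two pushes come straight back; from the third push on, each returned sample is either the new one or a random one from history.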

As for why this is a sensible strategy, see the referenced paper.
