對抗生成網絡(GAN)中損失函數的理解


對抗生成網絡(GAN)中損失函數的理解

最近開始接觸對抗生產網絡,目地是用GAN生成一些假樣本,去解決樣本不平衡的問題。

看了兩天GAN的代碼,沒有太多特別的地方,因為之前看論文的時候就已經知道大體的結構。但是唯一沒有搞清除的就是:生成器和判別器的損失函數,以及損失函數是怎么向后傳播,去更新權重的。

簡述一些GAN的訓練過程:

1、先定義一個標簽:valid = 1,fake = 0。當然這兩個值的維度是按照數據的輸出來看的。再定義了兩個優化器。用於生成器和判別器。

2、隨機生成一個噪聲z。將z作為生成器的輸入,輸出gen_imgs(假樣本)。

3、計算生成器的損失:(這里就很有趣了)

定義:生成器的損失為g_loss。損失函數為adverisal_loss()。判別器為discriminator()。

g_loss = adverisal_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

可以看出來,g_loss是根據一個輸出(將生成的樣本作為輸入的判別器的輸出)與valid的一個損失。這是個什么意思?這里為什么扯進來了判別器?

用白話來一步一步解釋:

<1> discriminator(gen_imgs) 的輸出是個什么?

既然是判別器,意思就是判別gen_imgs是不是真樣本。如果是用softmax輸出,是一個概率,為真樣本的概率。

<2> g_loss = adverisal_loss(discriminator(gen_imgs), valid)

計算g_loss就是判別器的輸出與valid的差距,讓g_loss越來越小,就是讓gen_imgs作為判別器的輸出的概率更接近valid。就是讓gen_imgs更像真樣本。

<3> 要注意的是,這個g_loss用於去更新了生成器的權重。這個時候,判別器的權重並沒有被更新。

4、分別把假樣本和真樣本都送入到判別器。

real_loss = adverisal_loss(discriminator(real_imgs), valid)
fake_loss = adverisal_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()

real_loss是判別器去判別真樣本的輸出,讓這個輸出更接近與valid。

fake_loss是判別器去判別假樣本的輸出,讓這個輸出更接近與fake。

d_loss是前兩者的平均。

損失函數向后傳播,向后傳播,就是為了讓d_loss ---> 0。也就是讓:

real_loss ---> 0 ===> 讓判別器的輸出(真樣本概率)接近 valid

fake_loss ---> 0 ===> 讓判別器的輸出(假樣本概率)接近 fake

也就是說,讓判別器按照真假樣本的類別,分別按照不同的要求去更新參數。

5、損失函數的走向?

g_loss 越小,說明生成器生產的假樣本作為判別器的輸入的輸出(概率)越接近valid,就是假樣本越像真樣本。

d_loss越小,說明判別器越能夠將識別出真樣本和假樣本。

所以,最后是要讓g_loss更小,d_loss更大。以至於d_loss最后為0.5的時候,達到最好的效果。這個0.5的意思就是:判別器將真樣本全部識別正確,所以real_loss=0。把所有的假樣本識別錯誤,此時fake_loss = 1。最后的d_loss = 1/2。

最開始的一個疑惑就是:生成器和判別器是如何相連接,同時更新權重的?。在第3步實現了!!!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM