對抗生成網絡(GAN)中損失函數的理解
最近開始接觸對抗生產網絡,目地是用GAN生成一些假樣本,去解決樣本不平衡的問題。
看了兩天GAN的代碼,沒有太多特別的地方,因為之前看論文的時候就已經知道大體的結構。但是唯一沒有搞清除的就是:生成器和判別器的損失函數,以及損失函數是怎么向后傳播,去更新權重的。
簡述一些GAN的訓練過程:
1、先定義一個標簽:valid = 1,fake = 0。當然這兩個值的維度是按照數據的輸出來看的。再定義了兩個優化器。用於生成器和判別器。
2、隨機生成一個噪聲z。將z作為生成器的輸入,輸出gen_imgs(假樣本)。
3、計算生成器的損失:(這里就很有趣了)
定義:生成器的損失為g_loss。損失函數為adverisal_loss()。判別器為discriminator()。
g_loss = adverisal_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
可以看出來,g_loss是根據一個輸出(將生成的樣本作為輸入的判別器的輸出)與valid的一個損失。這是個什么意思?這里為什么扯進來了判別器?
用白話來一步一步解釋:
<1> discriminator(gen_imgs) 的輸出是個什么?
既然是判別器,意思就是判別gen_imgs是不是真樣本。如果是用softmax輸出,是一個概率,為真樣本的概率。
<2> g_loss = adverisal_loss(discriminator(gen_imgs), valid)
計算g_loss就是判別器的輸出與valid的差距,讓g_loss越來越小,就是讓gen_imgs作為判別器的輸出的概率更接近valid。就是讓gen_imgs更像真樣本。
<3> 要注意的是,這個g_loss用於去更新了生成器的權重。這個時候,判別器的權重並沒有被更新。
4、分別把假樣本和真樣本都送入到判別器。
real_loss = adverisal_loss(discriminator(real_imgs), valid)
fake_loss = adverisal_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
real_loss是判別器去判別真樣本的輸出,讓這個輸出更接近與valid。
fake_loss是判別器去判別假樣本的輸出,讓這個輸出更接近與fake。
d_loss是前兩者的平均。
損失函數向后傳播,向后傳播,就是為了讓d_loss ---> 0。也就是讓:
real_loss ---> 0 ===> 讓判別器的輸出(真樣本概率)接近 valid
fake_loss ---> 0 ===> 讓判別器的輸出(假樣本概率)接近 fake
也就是說,讓判別器按照真假樣本的類別,分別按照不同的要求去更新參數。
5、損失函數的走向?
g_loss 越小,說明生成器生產的假樣本作為判別器的輸入的輸出(概率)越接近valid,就是假樣本越像真樣本。
d_loss越小,說明判別器越能夠將識別出真樣本和假樣本。
所以,最后是要讓g_loss更小,d_loss更大。以至於d_loss最后為0.5的時候,達到最好的效果。這個0.5的意思就是:判別器將真樣本全部識別正確,所以real_loss=0。把所有的假樣本識別錯誤,此時fake_loss = 1。最后的d_loss = 1/2。
最開始的一個疑惑就是:生成器和判別器是如何相連接,同時更新權重的?。在第3步實現了!!!
