对抗生成网络(GAN)中损失函数的理解


对抗生成网络(GAN)中损失函数的理解

最近开始接触对抗生产网络,目地是用GAN生成一些假样本,去解决样本不平衡的问题。

看了两天GAN的代码,没有太多特别的地方,因为之前看论文的时候就已经知道大体的结构。但是唯一没有搞清除的就是:生成器和判别器的损失函数,以及损失函数是怎么向后传播,去更新权重的。

简述一些GAN的训练过程:

1、先定义一个标签:valid = 1,fake = 0。当然这两个值的维度是按照数据的输出来看的。再定义了两个优化器。用于生成器和判别器。

2、随机生成一个噪声z。将z作为生成器的输入,输出gen_imgs(假样本)。

3、计算生成器的损失:(这里就很有趣了)

定义:生成器的损失为g_loss。损失函数为adverisal_loss()。判别器为discriminator()。

g_loss = adverisal_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

可以看出来,g_loss是根据一个输出(将生成的样本作为输入的判别器的输出)与valid的一个损失。这是个什么意思?这里为什么扯进来了判别器?

用白话来一步一步解释:

<1> discriminator(gen_imgs) 的输出是个什么?

既然是判别器,意思就是判别gen_imgs是不是真样本。如果是用softmax输出,是一个概率,为真样本的概率。

<2> g_loss = adverisal_loss(discriminator(gen_imgs), valid)

计算g_loss就是判别器的输出与valid的差距,让g_loss越来越小,就是让gen_imgs作为判别器的输出的概率更接近valid。就是让gen_imgs更像真样本。

<3> 要注意的是,这个g_loss用于去更新了生成器的权重。这个时候,判别器的权重并没有被更新。

4、分别把假样本和真样本都送入到判别器。

real_loss = adverisal_loss(discriminator(real_imgs), valid)
fake_loss = adverisal_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()

real_loss是判别器去判别真样本的输出,让这个输出更接近与valid。

fake_loss是判别器去判别假样本的输出,让这个输出更接近与fake。

d_loss是前两者的平均。

损失函数向后传播,向后传播,就是为了让d_loss ---> 0。也就是让:

real_loss ---> 0 ===> 让判别器的输出(真样本概率)接近 valid

fake_loss ---> 0 ===> 让判别器的输出(假样本概率)接近 fake

也就是说,让判别器按照真假样本的类别,分别按照不同的要求去更新参数。

5、损失函数的走向?

g_loss 越小,说明生成器生产的假样本作为判别器的输入的输出(概率)越接近valid,就是假样本越像真样本。

d_loss越小,说明判别器越能够将识别出真样本和假样本。

所以,最后是要让g_loss更小,d_loss更大。以至于d_loss最后为0.5的时候,达到最好的效果。这个0.5的意思就是:判别器将真样本全部识别正确,所以real_loss=0。把所有的假样本识别错误,此时fake_loss = 1。最后的d_loss = 1/2。

最开始的一个疑惑就是:生成器和判别器是如何相连接,同时更新权重的?。在第3步实现了!!!


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM