U-GAT-IT筆記


由於博客園有時候公式顯示不出來,建議在https://github.com/FangYang970206/PaperNote/blob/master/GAN/UGATIT.md下載markdown文件,用typora(最強markdown編輯器)打開。

前言

介紹一下最近出的U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive
Layer-Instance Normalization for Image-to-Image Translation,首先看看這篇論文達到的效果。

1566875612246

第一行是原圖,第二行是熱力圖,第三行是生成的圖像,例子中包括人臉到動漫,馬到斑馬,貓到狗,人臉到畫等等,由於網絡采用cycle-gan的形式,所以可以互轉(動漫到人臉)。

這篇文章的效果和指標都很不錯,值得一看,首先說說題目,可以說題目是包含了文章的主要特色。有以下幾點:

  • Unsupervised Generative Networks :體現在cycle-gan的結構,不需要成對(unpair)的數據。
  • Attentional:體現在有權重的特征圖,具體做法是根據輔助分類器得到的注意圖,通過區分源域和目標域,幫助模型知道在哪里集中轉換。
  • Adaptive Layer-Instance Normalization:引入了自適應的LN和IN的混合歸一化層,幫助我們的注意力引導模型在不修改模型架構或超參數的情況下靈活控制形狀和紋理的變化量。

模型結構

整個網絡是類似cycle-gan的結構,AtoB和BtoA的生成器是一樣的,鑒別器也是一樣的,所以這里只說一個就可以了。

生成器

1566877302224

首先圖像經過一個下采樣模塊,然后經過一個殘差塊,得到編碼后的特征圖,編碼后的特征圖分兩路,一路是通過一個輔助分類器,得到有每個特征圖的權重信息,然后與另外一路編碼后的特征圖相乘,得到有注意力的特征圖。注意力特征圖依然是分兩路,一路經過一個1x1卷積和激活函數層得到黃色的a1...an特征圖,然后黃色特征圖通過全連接層得到解碼器中 Adaptive Layer-Instance Normalization層的gamma和beta,另外一路作為解碼器的輸入,經過一個自適應的殘差塊(含有Adaptive Layer-Instance Normalization)以及上采樣模塊得到生成結果。

這里說一下Adaptive Layer-Instance Normalization的具體公式:

\[\hat{a}_{I}=\frac{a-\mu_{I}}{\sqrt{\sigma_{I}^{2}+\epsilon}}, \hat{a}_{L}=\frac{a-\mu_{L}}{\sqrt{\sigma_{L}^{2}+\epsilon}} \]

上面是IN和LN的歸一化公式,然后將\(\hat{a}_{I}\)\(\hat{a}_{L}\)代入到進行合並(\(\gamma\)\(\beta\)通過外部傳入):

\[\operatorname{AdaLIN}(a, \gamma, \beta)=\gamma \cdot\left(\rho \cdot \hat{a}_{I}+(1-\rho) \cdot \hat{a}_{L}\right)+\beta \]

為了防止\(\rho\)超出[0,1]范圍,對\(\rho\)進行了區間裁剪:

\[\rho \leftarrow c l i p[0,1](\rho-\tau \Delta \rho) \]

AdaIN能很好的將內容特征轉移到樣式特征上,但AdaIN假設特征通道之間不相關,意味着樣式特征需要包括很多的內容模式,而LN則沒有這個假設,但LN不能保持原始域的內容結構,因為LN考慮的是全局統計信息,所以作者將AdaIN和LN結合起來,結合兩者的優勢,有選擇地保留或改變內容信息,有助於解決廣泛的圖像到圖像的翻譯問題。

當然,說的再多,看源碼是最直觀,最清楚的,附上注解后的pytorch官方源碼

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc    #輸入通道數 --> 3
        self.output_nc = output_nc  #輸出通道數 --> 3
        self.ngf = ngf              #第一層卷積后的通道數 --> 64
        self.n_blocks = n_blocks	#殘差塊數 --> 6
        self.img_size = img_size    #圖像size --> 256
        self.light = light          #是否使用輕量級模型

        DownBlock = []
        # 先通過一個卷積核尺寸為7的卷積層,圖片大小不變,通道數變為64
        DownBlock += [nn.ReflectionPad2d(3),
                      nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                      nn.InstanceNorm2d(ngf),
                      nn.ReLU(True)]

        # Down-Sampling --> 下采樣模塊
        n_downsampling = 2
        # 兩層下采樣,img_size縮小4倍(64),通道數擴大4倍(256)
        for i in range(n_downsampling): 
            mult = 2**i
            DownBlock += [nn.ReflectionPad2d(1),
                          nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
                          nn.InstanceNorm2d(ngf * mult * 2),
                          nn.ReLU(True)]

        # Down-Sampling Bottleneck  --> 編碼器中的殘差模塊
        mult = 2**n_downsampling
        # 6個殘差塊,尺寸和通道數都不變
        for i in range(n_blocks):
            DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]

        # Class Activation Map --> 產生類別激活圖
        #接着global average pooling后的全連接層
        self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
        #接着global max pooling后的全連接層
        self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
        #下面1x1卷積和激活函數,是為了得到兩個pooling合並后的特征圖
        self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block --> 生成自適應 L-B Normalization(AdaILN)中的Gamma, Beta
        if self.light: # 確定輕量級,FC使用的是兩個256 --> 256的全連接層
            FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True)]
        else:
            #不是輕量級,則下面的1024x1024 --> 256的全連接層和一個256 --> 256的全連接層
            FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False), # (1024x1014, 64x4) crazy
                  nn.ReLU(True),
                  nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True)]
        #AdaILN中的Gamma, Beta
        self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
        self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)
		
        # Up-Sampling Bottleneck --> 解碼器中的自適應殘差模塊
        for i in range(n_blocks):
            setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))

        # Up-Sampling --> 解碼器中的上采樣模塊
        UpBlock2 = []
        #上采樣與編碼器的下采樣對應
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
                         nn.ReflectionPad2d(1),
                         nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
                         ILN(int(ngf * mult / 2)), #注:只有自適應殘差塊使用AdaILN
                         nn.ReLU(True)]
		#最后一層卷積層,與最開始的卷積層對應
        UpBlock2 += [nn.ReflectionPad2d(3),
                     nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
                     nn.Tanh()]
		
        self.DownBlock = nn.Sequential(*DownBlock) #編碼器整個模塊
        self.FC = nn.Sequential(*FC)               #生成gamma,beta的全連接層模塊
        self.UpBlock2 = nn.Sequential(*UpBlock2)   #只包含上采樣后的模塊,不包含殘差塊

    def forward(self, input):
        x = self.DownBlock(input)  #得到編碼器的輸出,對應途中encoder feature map

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) #全局平均池化
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) #gap的預測
        gap_weight = list(self.gap_fc.parameters())[0] #self.gap_fc的權重參數
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3) #得到全局平均池化加持權重的特征圖

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) #全局最大池化
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) #gmp的預測
        gmp_weight = list(self.gmp_fc.parameters())[0] #self.gmp_fc的權重參數
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) #得到全局最大池化加持權重的特征圖

        cam_logit = torch.cat([gap_logit, gmp_logit], 1) #結合gap和gmp的cam_logit預測
        x = torch.cat([gap, gmp], 1)  #結合兩種池化后的特征圖,通道數512
        x = self.relu(self.conv1x1(x)) #接入一個卷積層,通道數512轉換為256

        heatmap = torch.sum(x, dim=1, keepdim=True) #得到注意力熱力圖

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1) #輕量級則先經過一個gap
            x_ = self.FC(x_.view(x_.shape[0], -1))
        else:
            x_ = self.FC(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_) #得到自適應gamma和beta


        for i in range(self.n_blocks):
            #將自適應gamma和beta送入到AdaILN
            x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
        out = self.UpBlock2(x) #通過上采樣后的模塊,得到生成結果

        return out, cam_logit, heatmap #模型輸出為生成結果,cam預測以及熱力圖


class ResnetBlock(nn.Module): #編碼器中的殘差塊
    def __init__(self, dim, use_bias):
        super(ResnetBlock, self).__init__()
        conv_block = []
        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class ResnetAdaILNBlock(nn.Module): #解碼器中的自適應殘差塊
    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm1 = adaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm2 = adaILN(dim)

    def forward(self, x, gamma, beta):
        out = self.pad1(x)
        out = self.conv1(out)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)
        out = self.pad2(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)

        return out


class adaILN(nn.Module): #Adaptive Layer-Instance Normalization代碼
    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        self.eps = eps
        #adaILN的參數p,通過這個參數來動態調整LN和IN的占比
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        #先求兩種規范化的值
        in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        #合並兩種規范化(IN, LN)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 
        #擴張得到結果
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
		
        return out


class ILN(nn.Module): #沒有加入自適應的Layer-Instance Normalization,用於上采樣
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, input):
        in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)

        return out

生成器的代碼如上,歸結下來有以下幾個點:

  • 編碼器中沒有采用AdaILN以及ILN,而且只采用了IN,原文給出了解釋:在分類問題中,LN的性能並不比批規范化好,由於輔助分類器與生成器中的編碼器連接,為了提高輔助分類器的精度,我們使用實例規范化(批規范化,小批量大小為1)代替AdaLIN;
  • 使用類別激活圖(CAM)來得到注意力權重;
  • 通過注意力特征圖得到解碼器中AdaILN的gamma和beta;
  • 解碼器中殘差塊使用的AdaILN,而其他塊使用的是ILN;
  • 使用鏡像填充,而不是0填充;
  • 所有激活函數使用的是RELU。

鑒別器

鑒別器相比生成器,要簡單許多,結構圖如下所示:

1566896822258

具體結構與生成器類似,不過規范化使用的是譜規范化,使訓練更加穩定,收斂得更好,激活函數使用的是leakyrelu,直接上代碼:

class Discriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(Discriminator, self).__init__()
        model = [nn.ReflectionPad2d(1),   #第一層下采樣, 尺寸減半(128),通道數為64
                 nn.utils.spectral_norm(
                 nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
                 nn.LeakyReLU(0.2, True)]

        for i in range(1, n_layers - 2): #第二,三層下采樣,尺寸再縮4倍(32),通道數為256
            mult = 2 ** (i - 1)
            model += [nn.ReflectionPad2d(1),
                      nn.utils.spectral_norm(
                      nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
                      nn.LeakyReLU(0.2, True)]

        mult = 2 ** (n_layers - 2 - 1)
        model += [nn.ReflectionPad2d(1), # 尺寸不變(32),通道數為512
                  nn.utils.spectral_norm(
                  nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
                  nn.LeakyReLU(0.2, True)]

        # Class Activation Map, 與生成器得類別激活圖類似
        mult = 2 ** (n_layers - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.model = nn.Sequential(*model)

    def forward(self, input):
        x = self.model(input)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        x = self.pad(x)
        out = self.conv(x) #輸出大小是32x32,其他與生成器類似

        return out, cam_logit, heatmap

損失函數

損失函數總共有四個,分別是Adversarial loss, Cycle loss, Identity loss以及CAM loss。

Adversarial loss:

\[\begin{aligned} L_{g a n}^{s \rightarrow t} &=\left(\mathbb{E}_{x \sim X_{t}}\left[\left(D_{t}(x)\right)^{2}\right]\right.\\ &\left.+E_{x \sim X_{s}}\left[\left(1-D_{t}\left(G_{s \rightarrow t}(x)\right)\right)^{2}\right]\right) \end{aligned} \]

對抗損失沒有采用原始的log函數,使用的是MSE.

Cycle loss:

\[L_{c y c l e}^{s \rightarrow t}=\mathbb{E}_{x \sim X_{s}}\left[\left|x-G_{t \rightarrow s}\left(G_{s \rightarrow t}(x)\right)\right|_{1}\right] \]

cycle-gan架構下的環一致性loss,A翻譯到B,然后B翻譯到A‘,A和A’需要相同,loss采用的是L1loss.

Identity loss:

\[L_{i d e n t i t y}^{s \rightarrow t}=\mathbb{E}_{x \sim X_{t}}\left[\left|x-G_{s \rightarrow t}(x)\right|_{1}\right] \]

Identity loss保證了輸入圖像A和輸出圖像B的顏色分布是相似的.

CAM loss

生成器和鑒別器的CAM loss有些不同:

生成器CAM loss,采用的是BCE_loss:

\[\begin{aligned} L_{\text {cam}}^{s \rightarrow t} &=-\left(\mathbb{E}_{x \sim X_{s}}\left[\log \left(\eta_{s}(x)\right)\right]\right.\\ &+\mathbb{E}_{x \sim X_{t}}\left[\log \left(1-\eta_{s}(x)\right)\right] \end{aligned} \]

鑒別器CAM loss, 采用的是MSE,沒有論文中log函數,是論文公式打印出錯了,詳細見issue

\[\begin{aligned} L_{c a m}^{D_{t}} &=\mathbb{E}_{x \sim X_{t}}\left[\left(\eta_{D_{t}}(x)\right)^{2}\right] \\ &+\mathbb{E}_{x \sim X_{s}}\left[ \left(1-\eta_{D_{t}}\left(G_{s \rightarrow t}(x)\right)\right)^{2}\right] \end{aligned} \]

用CAM的原因是利用輔助分類器\(\eta_{s}\)\(\eta_{D_{t}}\)的信息,給定一個圖像\(x \in\left\{X_{s}, X_{t}\right\}\)\(G_{s \rightarrow t}\)\(D_{t}\)了解它們需要改進的地方,或者在當前狀態下兩個域之間的最大區別是什么。

實驗結果

下面論文中的效果對比圖,確實有效地控制形狀和紋理,沒有發生較大地畸變,很不錯。

下面的是實驗指標:

1566977116709

可以看到在單個物體的翻譯效果很好,特別是在selfie2anime,由於喜歡動漫,看到效果圖,這才仔細看了哈哈哈。

結語

作者在https://github.com/taki0112/UGATIT/issues/6給出了selfie2anime數據集以及他們的預訓練模型,想生成自己的動漫頭像,盤起來吧!😄


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM