MAE源代碼理解 part1 : 調試理解法



                git官方鏈接: GitHub - facebookresearch/mae: PyTorch implementation of MAE https//arxiv.org/abs/2111.06377

 

下了MAE代碼 完全看不懂 我要一步一步來 把這篇代碼給全部理解了 。我自己覺得看大神代碼很有用。 這篇文章當筆記用。

一,跑示例:

怎么說 一上來肯定是把demo里的代碼拿出來跑一跑。但是會遇到問題。 下面時demo的代碼。 第一個問題是

TypeError: __init__() got an unexpected keyword argument 'qk_scale'

說函數沒這個參數 那很簡單 找到位置 刪掉就行 為啥我敢刪 就是因為他的值是 None ,直接刪就行

第二個問題是 我一開始把

這三個模型當成了預訓練模型 , 下面左就是得到的結果 這啥啊 還原了個寂寞 。 想了半天kaiming是不是錯了 ,再想了半天kaiming怎么會錯 ,才發現預訓練模型藏在鏈接里。下面這三個只是他開始訓練時使用的預訓練模型。

 

鏈接在demo里找到  兩個large的 模型參數如下  跑的結果如上右 對嘛 

https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

復現結束了 (bushi) 

        終於把演示跑通了。

 

 

  2 畫圖

調試這個方法可太神了,我們上面跑通了demo  就讓我們跟着demo一覽模型全貌吧!

這段 獲取圖像並且歸一化  然后用plt畫出來  這里是先歸一化 畫圖時再返回回來。 

(吐槽  : 我不理解 為什么要先歸一化 再回來 再畫圖  多此一舉? 我直接show img 不香嗎)

# load an image
img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145
# img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851
img = Image.open(requests.get(img_url, stream=True).raw)
#raw是一種格式 stream 是確定能下再下。(比如會事先確定內存)
img = img.resize((224, 224))
img = np.array(img) / 255.

assert img.shape == (224, 224, 3)

# normalize by ImageNet mean and std
img = img - imagenet_mean
img = img / imagenet_std

plt.rcParams['figure.figsize'] = [5, 5]   #設置畫布尺寸
show_image(torch.tensor(img))

def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    #剛才歸一化了 現在返回 記得clip防止越界 int防止小數  因為像素都是整數   imshow竟然可以讀張量
    plt.title(title, fontsize=16)
    plt.show()
    plt.axis('off')
    return

3 載入模型  

3.1准備模型

chkpt_dir = 'model_save/mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

會進入准備模型的函數里 

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)

    return model

對於第一局 getattr(models_mae,arch):   是取models_mae模塊里的arch  而這個arch是什么 下圖可以看到是一個函數 而且是一個沒帶括號的函數 (我不理解 ) 所以get后要補一個括號

        

然后我們進入這個函數, 可以看到這個函數了 哦~ 是一個獲取模型的函數 大 中小模型有三個不同的函數 不同函數的參數不一樣罷了。

 然后就是一個大工程了 我們進這個模型內部看一看。

3.2.1_模型內部 

        模型代碼太大了 我就不貼整個的了 我一部分一部分的貼。

3.2.1.1  編碼器模塊

from timm.models.vision_transformer import PatchEmbed, Block


self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
#patch_size 應該是一個圖片分出來的 一張有多大  inchans 一般都是3 圖片層數嘛
# embed——dim 這個是編出來的特征維度 1024


num_patches = self.patch_embed.num_patches
##num_pathches 大小是x*y 就是圖片分成x*y份num_patches = (224/patch_size)**2 = 14 **2 = 196

這個編碼 來自於VIT的編碼, 然而我並沒有看過VIT的代碼是什么樣子的 。這篇里先不寫 ,等到下一篇文章 我就遍歷進這個編碼函數里 看看是什么東西。 我們就記住 有一個編碼的函數 似乎是吧圖片 變成一串特征碼  

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),         
    requires_grad=False)  # fixed sin-cos embedding

cls令牌 加入 位置編碼加入  nn.patameter這個函數 就是將一個不可訓練的張量或者矩陣 轉換為模型內可以訓練的參數。 (想寫一個要訓練的參數 又不是官方的那些層 ,終於知道方法啦)。cls_token大小是 (1,1,1024) 位置編碼是 (1,197,1024) 為啥是197呢 ?應該是為了跟嵌入cls后的編碼大小保持一致 然后可以cat  我猜。

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

這里的 block 就是VIT里的那個block  這個block也等到VIT代碼時再講

這里有幾個他們用的小trick

        

nn.LayerNorm   #這個表示在channel 上做歸一化 
nn.batchNorm  #這個是在batch上歸一化
DropPath  # 這個也是一種與dropout不同的 drop方法
nn.GELU   #一種激活函數 

nn.ModuleList  其實就是一個列表 把一些塊放在這個列表里  與普通列表不同的是 普通的列表不會得到訓練 。 這里就是放了24個自注意力塊  每個塊有12個頭   。以上就是編碼器用到的模塊。

3.2.1.2 解碼模塊

下面是解碼器。 

        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
            # 一個fc層 1024到512

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
            #一個mask編碼 (1,1,512)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1,         
            decoder_embed_dim), requires_grad=False)  # fixed  sin-cos embedding
              #一個位置編碼 而且不訓練 (1,197,512)  為什么不訓練啊?
        
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        #預測層  512 到   256*3 (這個也不到224*224*3啊)

解碼器的注意力層只有8層 但也是12頭的  輸入是512維 

3.2.1.3 初始化模塊 

3.2.1.3.1 找位置編碼 

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

第一個的值是false 等會看看有啥用  第二個是一個函數 我們進去看看 。

       pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

初始化 第一步 是一個位置編碼函數 ,我們進入這個編碼函數去看 

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):  
    #embed_dim = 1024 是位置的最后一維 gridSize是每個小patch的長寬 也就是14
    
  
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
      #生成兩個坐標系 14*14的 


    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    #這就是一個坐標系了 不過誰是x 誰是y還要看看 
    grid = np.stack(grid, axis=0)
    #  生成了 兩個網格。 每個都是14*14  grid現在是(2,14,14)


    grid = grid.reshape([2, 1, grid_size, grid_size])  
    #(2,1,14,14)


   

然后繼續進入下層函數 我們繼續看 。

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
    return emb

再進入下層函數 。

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):

    

    """
    embed_dim: output dimension for each position   這里只有512
    pos: a list of positions to be encoded: size (M,)    #這里是(1,14,14) 相當於一個通道
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float)
    #    (1,2,3,4.。。。256)
    omega /= embed_dim / 2.
    #這一步是歸一化

    omega = 1. / 10000**omega  # (D/2,)
    ##有點像做了個反向 本來是0到1 現在是1到0

    pos = pos.reshape(-1)  # (M,)
    #1,14,14 變成了 196  形式是0到13循環14次 
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        #這里是外集 就是一列乘一行 相當於  out就變成 (196, 256)的矩陣了。

    emb_sin = np.sin(out) # (M, D/2)
    emb_cos = np.cos(out) # (M, D/2)
    
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)

    #對所有值取sin 和cos 之后con起來  但注意維度是1 也就是196*512 前半段是sin 后半段cos
    return emb

下層函數返回后 再次拼起來 變成 196 *1024  這個位置編碼真可謂是歷盡艱辛 。我們來看 他是怎么來的 。首先 196, 1024分前后兩段。看前半段 。 先做個(256,1)長的矩陣 分布再1,256 表示位置 之后呢 再反向后與網格(14*14)拉平后的值做一個外積 這個網格也是位置信息。之后sin 和cos都上 得到兩個位置編碼。 再拼起來 得到一個維度的編碼 。 再把兩個維度拼起來得到整體的位置編碼。 

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

這里是  將196 1041 , 變成(197,1024)  拼出CLS那一維。

3.2.1.3.2回到初始化

        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        #將numpy變為tensor 后 轉float32 再擴充維度為(1,197,1024) 就得到了編碼器的位置編碼
  decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

解碼器的位置編碼  (1,197,512) 還是比編碼器少了一半

        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))


        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

這個w是取出weight層的權重值。 正好可以看出 w的大小是 (1024,3,16,16) 1024是輸出維度 3是輸入維度 。相當於一個卷積 ? 然后參數進行一個初始化 統一於 (1024, 3*16*16)正太分布 

 mask 和 cls 也要初始化 。

        

self.apply(self._init_weights)

初始化其他層 self.apply應該是對遍歷模型 對每一個模塊 使用后面這個函數 我們進入初始化權重函數看一看 ,

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

可以看到是如何初始化的 全連接層的 權重使用xavier的均勻分布  偏置設為0 

layer歸一化層 的偏置為0 權重為1 

過程中可以看到對24個注意力層都初始化 而且注意力層里也有各種各樣的linear層。

3.2.1.3.3 初始化完成

至此 模型的初始化完成了 我們得到了這個模型。從這些步驟里 我們可以大概看到模型是什么樣子的 , 有一個編碼器模塊 和一個解碼器模塊。 編碼器模塊有24層深的16頭自注意力模塊。 還有一些位置編碼和 cls 編碼   而解碼器只是多了一個mask編碼,而且維度會與編碼器不一樣。

3.3 模型准備完成。 

checkpoint = torch.load(chkpt_dir, map_location='cpu')

這個chkpt_dir 也就是下載下來的預訓練模型 大概應該只是參數  所以需要下面這句 模型載入參數

這里這個strict 意思是 如果與預訓練有的層 就使用預訓練的參數  模型里 預訓練沒有的層 就普通初始化。 

msg = model.load_state_dict(checkpoint['model'], strict=False)

return model

msg 記錄加載的結果  得到完全體模型。

 

4處理圖片 

模型准備好了 我們開始用模型處理一個圖片看看 。 

4.1數據准備 

torch.manual_seed(2)   #固定隨機數種子
print('MAE with pixel reconstruction:')
run_one_image(img, model_mae)

 我們進入了 run_ONE_image函數內部 

    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

這里顯示了怎么把一個 圖片 做成一個batch  第三個einsum 也可以用

torch.transpose() 這個函數來  就是一個維度的轉換嘛 把那個3 提到第二維上來。 不過he他們確實精妙  大佬。
    loss, y, mask = model(x.float(), mask_ratio=0.75)

進入模型運行了 。  從模型返回的是loss 預測值 和mask 我們進模型內部看看  注意模型中運算的值都是float32 格式的 。 

latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

進froward第一句 就是這一句  我們接下來進入前向編碼器里看一看 。 

4.2編碼步驟 

    def forward_encoder(self, x, mask_ratio):    
        # embed patches

        x = self.patch_embed(x)  #x:(1,3,224,224)->(1,196,1024)   14*14個片編碼

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :] # pos是1,197,1024 這里不要0的cls位置  位置信息是直接加到片編碼上的  和我的想法很不一樣  這樣加上來真的會有效果么 。

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))  #計算需要剩余多少片
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1] noise(1,196)
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # 是對noise的值進行排序  ids_shuffle得到的是下標值。

        ids_restore = torch.argsort(ids_shuffle, dim=1)  #對排序后得到的下標 再排序?這一步#我非常的不懂  后面看 
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]    #保持噪聲值小的那一堆?
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
#這個gather 就是在 x的 dim維 挑index的數。  但是好奇的是 這一串下來 不就是隨機挑嗎?
# index的維度是 (1,49,1024)X是(1,196,1024) x_masked 是(1,49,1024)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        #mask 是(1,196) 其中前49都是0 后面都是1 
        
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        #到這里終於明白了  這個ids_REStore的作用 就是把mask當成noise 然后把mask按照#restore的位置排序  這樣得到的mask就是一個  有mask的地方為1 沒mask的地方為0的二維張量。

        return x_masked, mask, ids_restore

這里的mask這里非常難以理解 所以我舉個例子 來看看 。 

首先 noise是隨機生成的  比如說是 noise = [2,0,3,1] 

                           然后 排序argsort: shuffle = [1,3,0,2]    到這里 是為了生成隨機數  我們取前兩個 也就是隨機出來的1,3 作為mask的下標 

                        對shuffle排序       :  restore = [2,0,3,1]

                      mask = [0,0,1,1]  我們根據restore對mask取數  得到[ 1,0,1,0]  下標1,3處就是0.            其實你可以把mask和shuffle看成一樣的 你用restore對shuffle 取數 得到【0,1,2,3】發現是排序好的 。 對【1,0,1,0】取數 得到[0,0,1,1]兩個是對應起來的。

處理cls 

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        #cls加上位置信息 
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
          # 這一句是為了防止批量的 也就是擴充復制 如果x的batch為N  cls也要復制N份


        x = torch.cat((cls_tokens, x), dim=1)

        #x:(1,50,1024) ->(1,50,1024)   原來是擴充在片數這一維。

這里x要經歷24個多頭自注意力的磨練  然后歸一化。

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

4.3解碼步驟 

回歸forward  來到第二局 解碼

pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        #x  (1,50,1024) ->(1,50,512)
        
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        
        ##ids_restore.shape[1] + 1 - x.shape[1] =196+1-50 =147也就是cls加片數減x=需要遮蓋數
        #self.maskroken.shape = (1,1,512)  mask_tokens = (1,147,512) repeate是幾就復制幾份

        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token cls辛辛苦苦一輩子 
         #就這樣沒了  我還沒看到你作用呢 麻煩半天  這里就是完成了 x和mask拼接后的X_


        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle    排序回去 按照 mask  index.shape = (1,196,512)
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token  無語

        # add pos embed
        x = x + self.decoder_pos_embed 

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)           
        x = self.decoder_norm(x)     #就八個 寒酸 

        # predictor projection
        x = self.decoder_pred(x)
        #### x (1,197.512) -> (1,197,768)


        # remove cls token  cls:你有毛病是吧。 
        x = x[:, 1:, :]

        return x

得到了模型預測的圖像結果  

4.4 loss探索 

下一步是loss

        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        target = self.patchify(imgs)

首先進入這個函數  p是一個小圖的大小 hw分別是yx方向圖的個數  都是14 

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

x 是(1,3,14,16,14,16) -(1,14,14,16,16,3)

然后reshape (1,14,14,16,16,3) -》(1,196,768) 此中過程 不足為外人道也 鬼知道你咋變的啊 。

target = self.patchify(imgs)   這句就是把原來的圖片 也編輯成(1,196,768)大小的 
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

這個歸一化 沒進去 

 可能因為本來已經歸過了 

 

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

loss是像素差平方  然后對最后一維求平均 變成了 (1,196) 也就是每一個小pat 一個loss

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

mask在相應沒有遮蓋的地方是0 所以就是只有遮蓋的地方才求loss  返回loss值。回到run

4.5 畫圖 

    loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)

進圖unpatchify 根據這個名字 可以看出是吧patch 還原成大圖 。

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

p 16 h w, 14,14   

x (1,196,768) -> (1,14,14,16,16,3) ->(1,3,14,16,14,16)  ->imgs(1,3,224,224) 

#我忽然想明白了 這里不用知道里面是怎么變化的 只需要操持一致即可  計算機自己就會把他們對應起來 又不用自己管。 

回到上面來 

    loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

y(1,3,224,224)- 》(1,224,224,3)

    # visualize the mask
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

mask:(1,196 )  ->(1,196,768) ->(1,3,224,224)  ->(1,224,224,3) 

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

x (1,3,224,224) ->(1,224,224,3)

1-mask  就是本來是0的 就是沒遮蓋的變成1 遮蓋的變成0 與x相乘 就得到遮蓋圖片 。

im_paste = x * (1 - mask) + y * mask  遮蓋的圖片 加上預測的Y與mask相乘 。 因為mask遮蓋的地方是1 所以直接相乘 

至此得到所有需要畫的圖像。, 

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()

無語淚凝噎 為啥圖不是一塊出來的 ????

 

 

 

原來是因為我改了代碼 

 

ok  完畢啦 演示結束 改天看其他模塊 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM