Pytorch框架學習---(6)hook函數和CAM類激活圖


本節簡單總結Pytorch中hook函數,CAM算法生成注意力圖【文中思維導圖采用MindMaster軟件
注意:對於真正運用CAM的代碼,本人后續隨着需要,再逐步更新。

1.hook函數

(1)定義

  不改變主體(前向、后向傳播等)情況下,實現額外的功能,如在backward之后,仍然可以得到特征圖和非葉子節點的梯度,即便它們被釋放。

(2)方法

節省精力, 由於網上已經有人對這4和hook函數總結的很好,故在此引用,不再復寫。

  這里我們直接來舉一個例子,使用hook函數可視化所有層的特征圖,即調用上面的register_forward_hook獲取網絡層的輸出:

# 注冊hook
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():  # 如果是named_children()則是返回Sequential本身features
        # print(sub_module)   # sub_module  Sequential本身features以及內部所有的網絡層features.0

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())   # 構建字典中key value對

            n1, n2 = name.split(".")  # features.0,  為nn.Sequential

            def hook_func(module, i, o):
                key_name = str(module.weight.shape)
                fmap_dict[key_name].append(o)  # 索引名字,添加特征圖
                # print("famp_dict:{}".format(fmap_dict))

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():  # 返回一個可迭代的列表
        fmap = fmap_list[0]  # 把list中元素取出
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=0)

  對每一個卷積層得到的特征圖,作tensorboard可視化:

注意:這里可視化卷積層,但是由於卷積層后面接的是激活函數relu,其中relu(inplace=True)原位操作,會對卷積層的輸出做一定的改變。

2.CAM(Class Activation Map)類激活圖

  啥話先不說,直接上圖!!!原來這個就是CAM算法出來的,當判別網絡將圖片歸為“貓”這個類別時,紅色代表網絡注意的地方,藍色則是沒有注意的地方:

(1)原始CAM

  最后一層卷積得到的特征圖,經過全局平均池化GAP,得到對應神經元向量,全連接層的權重,即是CAM對特征圖加權的權重,經過加權之后的特征圖即是最終類似注意力的激活圖。

局限性:最后必須是GAP,需要改動原始網絡並重新訓練,因而改進版Grad-CAM上線

(2)Grad-CAM(利用特征圖的梯度,作為加權權重)

  對特征圖梯度做平均,得到n個特征圖對應的n個平均梯度,將其作為CAM權重。

實戰代碼如下參考:github,后續用到CAM時,再放入自己項目的激活圖展示代碼。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM