介紹
在圖像識別和目標分類領域往往存在一些關於圖像中虛假相關性的問題,最典型的如將圖像中識別的主體(object)和背景(background)之間的相關性考慮成為識別主題類別的一個主要特征。如下圖所示,注意力模型將ground
作為一個判斷為鳥類的標簽,在預測地上的熊的時候就做出了錯誤的判斷。這種在不一樣的場景下的圖片,反而往往是在特殊緊急時,錯誤會十分致命。

圖1
將問題形式化描述:數據集包含輸入圖像\(X\),標簽\(Y\),標簽被通用的混淆因子——背景\(S\)所影響,模型學習了潛在的虛假因果,將\(S\)作為識別\(Y\)的特征。
有一種方式是通過因果干預來減輕混淆偏差。例如收集bird
類別在所有場景下的圖像。這樣模型就會只關注於object本身。然而在實際應用中這種方式消耗人力比較大。
在實際應用中,不可能找到某個類別在所有場景下的,如🐟在天空中就不容易找到。而從技術上講,這種方式違背了混淆的正性原則。因此需要將多個類別混合在一起(如圖2中混合地面和水)。

圖2
然而這種粗背景的划分方式會導致過度調整(over-adjustment)的問題,這種干預方式不僅會移除背景,還會損傷原本object中有用的特征,圖2展示了一個實例,鳥的翅膀在天空的背景下都是打開的,而在陸地等背景下是折疊的,因此陸地背景其實損壞了翅膀這個特征。圖中的split 4不僅表達了“天空”和“草地”,也表達了“翅膀”這個信息。將這個問題稱為“不合適的因果干預”(improper causal intervention)。
在本文中,作者提出一種因果注意力模型——CaaM,迭代生成數據的每個部分,並且逐漸地自行標注(self-annotates)混淆因子,克服over-adjustment問題。與更粗的上下文相比,多個CaaM分區粒度更細,更准確地描述了全面的混雜。如圖2左下所示,最后\(\mathcal{T}_N\)的每一個split都包含展開翅膀的圖像,翅膀特征就不再和背景具有相關性了。從技術上講,除了注意力機制試圖學習因果特征,CaaM還具有一種互補的注意,故意捕捉混淆效應(如背景)。兩個解糾纏注意以對抗極小極大的方式進行優化,它們逐漸構成混雜集,並以無監督的方式對混雜偏置進行控制。
CaaM:Causal Attention Module
因果知識
有偏情境下的因果視圖

圖3
在有偏差的情況下,因果圖如圖3(a)所示,\(X\)為輸入圖像,\(Y\)為標簽,\(S\)混淆因子,\(M\)中介變量。
- \(X\rightarrow Y\)表示圖像內容對\(Y\)的直接影響。
- \(X \leftarrow S \rightarrow Y\),在這里\(S\)不再單純表示背景,而是圖片的上下文混淆,決定圖像主題和背景如何在\(X\)上面布局,因此\(S\)決定\(X\),而這種布局和背景不可避免地會影響標簽\(Y\)。
- \(X \rightarrow M \rightarrow Y\)表示圖像中的有用因果特征,\(M\)表示的特征不會隨着域的遷移而產生分布的變化。雖然\(X \rightarrow M \rightarrow Y\)可以被隱藏在\(X \rightarrow Y\)的路徑中,但是為了方便推導,我們還是將其分離出來。
基於數據分塊的干預
數據分塊(data partition)是一種進行因果干預的有效方法。首先它將原始硬分割成\(\mathcal{T}=\{t_1, \cdots,t_m\}\),其中每一份都表示一個混淆層,這種方法的效果等同於后門調整:
在每一個split上面訓練相當於模擬\(P(Y|X,t)\)的分布,如圖3(b)所示,它剪斷了\(X \leftarrow S \rightarrow Y\)的后門路徑。然而現有的方法在某些split上只有很少的數據,離公式1的要求差距還比較大。
不合適的因果干預
由於現有的標注方法很難解耦混淆(\(S\))和因果特征(\(M\)),因此如公式1所示的基於上下文的干預很難實現。下面作者展示如何正確地使用因果干預。
假設划分\(\mathcal{T}\)只包含混淆,那么我們可以通過屏蔽\(M\)來減輕\(S\)的影響。1式可以寫作
然而當每個分割\(\mathcal{T}\)
中既包含\(S\)也包含\(M\),那么就會導致\(S\)與\(M\)不獨立。式2演變為式3。
此時\(X \rightarrow M \rightarrow Y\)這條邊就收到了損害,如圖3(c)所示。
原文說的是剪斷這條邊,但是我認為這個說法有點奇怪,應該是\(M\)的部分隨着\(\mathcal{T}\)的划分,因果特征\(M\)的一部分與\(X\)已經獨立了,所以這條邊收到損傷)
訓練流程

訓練流程如圖4所示。為了擴大每個split的大小,作者在每一步挖掘部分\(\mathcal{T}_i\),在\(N\)輪迭代之后,我們可以將1式近似為\(P(Y|do(X)) \approx\sum_i^N \sum_{t \in \mathcal{T}_i} P(Y|X, t)P(t)\)。
為了將混淆因子和中介變量\(M\)分開,我們設計兩個注意力模塊\(A, \overline{A}\),其中\(A\)是為了計算因果特征,而\(\overline{A}\)計算混淆特征,二者角色相反。對抗訓練促進解耦,
然后我們使用\(\overline{A}\)來更新\(\mathcal{T}_i\)。下面介紹訓練損失的詳細內容。
交叉熵損失
這個損失是為了保證\(A\)和\(\overline{A}\)的結合可以捕捉到\(X \rightarrow Y\)的總偏差效應,而不考慮因果或者混淆的影響,否則,他們可能違反圖3(a)中的數據生成機制。
注意這種有偏訓練廣泛應用於無偏模型(沒看懂)。
其中\(\tilde{x}=\mathcal{A}(x) \circ \overline{\mathcal{A}}(x)\),\(\circ\)表示特征相加,\(f\)為線性分類器,\(\mathscr{l}\)為交叉熵損失。
不變損失
這個損失是用來學習\(\mathcal{A}\)的,
它是由式1中的因果干預造成的split不變量,通過不完全混雜分區\(\mathcal{T}_i\)計算:
其中\(t\)是數據分組,\(g\)是用來預測魯棒特征的線性網絡,\(\mathrm{w}\)表示一個虛擬用於計算跨越分割的梯度懲罰的分類器,\(\lambda\)是權值。在推理階段,\(g(\mathcal{A}(x))\)被用於無偏識別。
對抗訓練
訓練過程通過一個最小化游戲(Mini-Game)和一個最大化游戲(Maxi-Game)來分開\(\mathcal{A}\)和\(\overline{\mathcal{A}}\)。
- 最大化游戲提取\(\overline{\mathcal{A}}(x)\)中的混淆特征,來生成數據塊\(\mathcal{T}_i\),因果特征不對最大化產生貢獻。
- 最小化游戲排除\(\mathcal{A}(x)\)中的混淆特征,混淆特征不對最小化產生貢獻。
最小化游戲(Mini-Game)
這是一個\(\mathrm{XE}\)和\(\mathrm{IL}\)的聯合訓練過程,加上一個新的對抗分類器\(h\),\(h\)專門用於研究由\(\overline{\mathcal{A}}(x)\)引起的混淆效應。
最大化游戲(Maxi-Game)
一個好的數據塊更新應該捕捉那些在split中變化的強混淆。
其中\(\mathcal{T}_i(\theta)\)指的是數據塊\(\mathcal{T}_i\)由參數\(\theta \in \mathbb{R}^{K \times m}\),\(K\)為總訓練數據量,而\(m\)是一個划分中的split數量。\(\theta_{p,q}\)指的是第\(p\)個sample屬於第\(q\)個split的概率。
CaaM的實現
作者將所提出的CaaM實現在兩種流行的基於注意力的深度模型上:基於CBAM的CNN、和
Transformer-based T2T-ViT。將結果模型分別稱為CNN-CaaM和vt - caam。為了簡單起見,在本節中,使用\(\mathbf{c}\)和\(\mathbf{s}\)來表示因果和混雜特征(即,\(\mathbf{c} = \mathcal{A}(x)\)和\(\mathbf{s} = \overline{\mathcal{A}}(x)\))。

圖5 基於CBAM的CNN-CaaM和基於T2T-ViT的ViT-CaaM模型結構。對於CNN-CaaM, D-Block被用來從CNN特征\(\mathbf{x}\)中分離因果特征\(\mathbf{c}\)(藍色)和混雜特征\(\mathbf{s}\)(橙色)。
D-Block (Init.)表示第一個D-Block。而M-Block將\(\mathbf{c}\)和\(\mathbf{s}\)與卷積層合並。然后將M-Block和D-Block疊加,逐步細化\(\mathbf{c}\)和\(\mathbf{s}\)。
CNN-CaaM
對於輸入特征\(x\),注意力計算:
其中\(\mathrm{z} \in \mathbb{R}^{w \times h \times c}\),\(\odot\)指的是元素點乘,因此,CaaM注意力表達如下:
模型結構如圖5(a)所示。
D-Block
D-Block是包含CaaM計算的塊,可以生成兩個注意力特征\(\mathbf{c}\)和\(\mathbf{s}\)。在D-Block之前,可以有很多個標准的殘差模塊,\(D-Block^{j+1}\)可以表示為:
- skip connection 是由標准ResNet塊的輸出連接的。
- 在混淆特征\(\mathbf{s}^j\)上移除skip connection,將其與因果特征\(\mathbf{c}^j\)區分開來。
M-Block
在進入D-Block之前,\(\mathbf{c}\)和\(\mathbf{s}\)被輸入進M-Block進行特征融合:
迭代10式和11式生成多層CaaM,在推理階段,最后的因果特征\(\mathbf{c}^{j+M-1}\)作為預測的魯棒特征。
ViT-CaaM
這個模型請參考原文。