Virtual Adversarial Training: a Regularization Method for Supervised and Semi-supervised Learning
簡介
本文是17年半監督學習的一篇文章,受對抗訓練的啟發,將對抗訓練的范式用於提升半監督學習,並且取得了非常好的效果。不同於最近一直比較火的對比學習,這些稍微“傳統”一點的方法我覺得還是有一定研究價值的,對比學習利用的增廣還是太多利用了人類的先驗,並不普適。
Intution
我們知道,大多數場景下,神經網絡的輸入都是連續的,那么如果我們能讓神經網絡平滑(對x的領域內的輸入有相似的輸出),那么就可以保證相似的輸入通過同一神經網絡得到相似的輸出,基於這樣的想法,那么自然就可以給沒有標簽的樣本一個與之輸入相近的偽標簽,這一算法稱之為label propagation。然而這樣做並不總是work,因為最近有大量的工作表明,神經網絡很容易收到輸入微小變動的攻擊,即輸入微小變動一點,輸出天差萬別。對抗樣本的生成會使得網絡遭到攻擊,從而上面讓網絡平滑的想法就無法實現。於是本文就想,能否利用生成對抗樣本的方法,使得輸入微小改變,但是仍然讓改變的樣本和改變之前的樣本的到相似的輸出?(這樣的方法應該被廣泛用於對抗攻擊的訓練當中,但是並沒有人將其用於半監督學習)
當前的半監督算法,一類是通過增廣,保證增廣前后樣本具有相似的輸出,這種可以理解為讓模型“平滑”;另一類是通過生成對抗網絡,生成樣本填充流形的低密度區域,此類方法並不需要模型“平滑”。但是后者往往缺乏合理的解釋,本文主要着手於前者研究。
Method
本文的VAT(Virtual Adversarial Training)方法最初的定義為這樣的Loss:
這里\(x_*\)是有標簽或無標簽的樣本。
上面用於優化兩分布的損失函數可以使用KL散度,而最重要的是如何計算\(r_{qadv}\)
我們先將\(D[q(y|x_*), p(y|x_*+r,\theta)]\)簡寫為\(D(r, x_*, \theta)\),假定\(p(y|x_*,\theta)\)關於\(\theta\)和\(x\)是二階處處可微的,我們知道,當r=0的時候,\(D(r, x_*, \theta)\)必定取得最小值,所以有\(\nabla_rD(r, x_*, \theta) |_{r=0} = 0\).
根據泰勒展開有:
而我們前面對r有有一個約束\(||r||_2 \leq \epsilon\)。根據瑞利熵原理,上式取得最大值時,\(r\)應為最大特征值對應的特征向量:
其中上划線代表將任意一非零向量投影到其對應方向的單位向量。
然后問題就變成了求海森矩陣的特征向量的問題了。
一般的,我們可能在numpy中會直接調用接口來求特征向量,但是獲得海森矩陣的計算還是挺大的,如果能根據一階的梯度來計算,運算就會小很多,但本文采用了冪迭代法來求解。
算法也就兩步:
Input: matrix H
Output: V main eigenvector of H
initialize V randomly;
repeat {
V <- HV
V <- V / ||V||
} until convergence;
令每次迭代
其中初始的d為隨機向量,最終d將收斂到特征向量u。
因而可以從一個隨機向量d出發,先對海森矩陣H做近似:
因此每一步迭代就變成了
這玩意變成梯度了,因此可以通過pytorch等自動求導到工具來實現了。
最終的Loss就是有監督的loss和上述adv loss加權。
Coding
import torch
import torch.nn as nn
import torch.nn.functional as F
def criterion(pred_p, pred_q):
p = F.softmax(pred_p, dim = 1)
q = F.softmax(pred_q, dim = 1)
return F.kl_div(p, q)
def vat_loss(model, x, iters, ep = 0.1):
model.eval()
pred = model(x)
# 1. 初始化隨機向量
d = torch.rand(x.shape)
d = F.normalize(d)
# 2. 冪迭代
for i in range(iters):
r = ep * d
r.requires_grad = True
d_ = criterion(pred, model(x + r))
d_.backward()
d = F.normalize(r.grad)
model.zero_grad()
model.train()
r_adv = ep * d
loss_adv = criterion(pred, model(x + r_adv))
return loss_adv
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, x):
return x**2
if __name__ == "__main__":
x = torch.randn(2,3)
net = SimpleNet()
loss = vat_loss(net, x, 3)
print(loss)
結果
在cifar10上復現的結果大致是: