[論文理解] Mutual Information Neural Estimation


Mutual Information Neural Estimation

互信息定義:

\(I(X;Z) = \int_{X \times Z} log\frac{d\mathbb{P}(XZ)}{d\mathbb{P}(X) \otimes \mathbb{P}(Z)}d\mathbb{P}(XZ)\)

CPC文章里用下面這個公式定義要更加容易理解,都是一樣的:

\[I(x;z) = \sum_{x,z}p(x,z) log \frac{p(x,z)}{p(x)p(z)} \]

互信息越大,表明兩個變量依賴關系越強,互信息越小,表示兩個隨機變量越獨立。

KL散度的對偶問題:

因此根據KL散度和其對偶問題之間的關系我們可以得到:

\[D_{K L}(\mathbb{P} \| \mathbb{Q}) \geq \sup _{T \in \mathcal{F}} \mathbb{E}_{\mathbb{P}}[T]-\log \left(\mathbb{E}_{\mathbb{Q}}\left[e^{T}\right]\right) \]

利用上式優化互信息的下界:

\[I(X ; Z) \geq I_{\Theta}(X, Z) \]

\[I_{\Theta}(X, Z)=\sup _{\theta \in \Theta} \mathbb{E}_{\mathbb{P}_{X Z}}\left[T_{\theta}\right]-\log \left(\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_{Z}}\left[e^{T_{\theta}}\right]\right) \]

優化算法:

一般來說z的分布用高斯分布,x和z的分布(marginal distribution)都好采樣;

對於joint distribution,用一個神經網絡來建模,即F(x,z),然后其結果就是joint distribution的采樣了。

代入公式計算即可。

class Mine(nn.Module):
    def __init__(self, input_size=2, hidden_size=100):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        
    def forward(self, input):
        output = F.elu(self.fc1(input))
        output = F.elu(self.fc2(output))
        output = self.fc3(output)
        return output

def mutual_information(joint, marginal, mine_net):
    t = mine_net(joint)
    et = torch.exp(mine_net(marginal))
    mi_lb = torch.mean(t) - torch.log(torch.mean(et))
    return mi_lb, t, et


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM