DAGs with NO TEARS: Continuous Optimization for Structure Learning
概
有向圖可以用鄰接矩陣\(A \in \{0, 1\}^{d \times d}\)來表示, 其中\(A_{ij} = 1\) 表示 node \(i\) 指向 node \(j\). 進一步的, 我們想要表示有向無環圖(DAG), 則\(A\)需要滿足額外的性質, 保證無環.
現在的問題是, 有一堆觀測數據\(X \in \mathbb{R}^{n \times d}\), 如何通過這些數據推測其(特征之間的)關系, 即對應的\(A\).
主要內容
首先, 假設特征之間滿足一個線性關系:
其中
\(z\)為隨機的噪聲.
通過\(W\)可以推出相應的\(A=\mathcal{A}(W)\), 即
故我們目標通常是:
其中\(\mathbb{D}\)表示有向無環圖.
進一步地, 因為我們希望\(W\)是一個系數的矩陣(否則斷然不是DAG), 故
並
顯然現在的關鍵是如何處理\(\mathcal{A}(W) \in \mathbb{D}\)這個條件, 以前的方法通常需要復雜的運算, 本文提出一種等價的條件
滿足
- \(h(W)= 0\)當且僅當\(\mathcal{A}(W) \in \mathbb{D}\);
- \(h(W)\)越小, 說明\(\mathcal{A}(W)\)越接近無環圖;
- \(h(W)\)是一個光滑函數;
- \(h(W)\)便於求導.
顯然1是期望的, 2可以用於判斷所得的\(W\)的優劣, 3, 4便於我們用數值方法求解.
等價條件的推導
\(\mathrm{tr}(I-W)^{-1} = d\)
Proposition 1: 假設\(W \in \mathbb{R}_+^{d \times d}\) 且 \(\|W\| < 1\), 則\(\mathcal{A}(W)\)能夠表示有向無環圖當且僅當
proof:
\(A = \mathcal{A}(W)\)能夠表示有向無環圖, 當且僅當
\(\Rightarrow\)
由於\(\|W\| < 1\)(最大奇異值小於1), 故
\(\Leftarrow\)
\(\mathrm{tr}(W^k) \ge 0\), 故
當且僅當
注: \(\|W\| < 1\)這個條件並不容易滿足.
\(\mathrm{tr}(e^W)=d\)
注: \(e^A = I + \sum_{k=1} \frac{A^k}{k!}\).
Proposition 2: 假設\(W \in \mathbb{R}_+^{d \times d}\), 則\(\mathcal{A}(W)\)能夠表示有向無環圖當且僅當
proof:
證明是類似的.
注: 此時對\(W\)的最大奇異值沒有要求.
\(\mathrm{tr}(W^k) = 0\)
這部分的證明可能應該歸屬於DAG-GNN.
Proposition 3: 假設\(W \in \mathbb{R}_+^{d \times d}\) , 則\(\mathcal{A}(W)\)能夠表示有向無環圖當且僅當
proof:
\(\Rightarrow\)是顯然的, 證明\(\Rightarrow\)只需說明
假設\(W\)的特征多項式為\(p(\lambda) = \sum_{k=0}^d \beta_k \lambda^k, \beta_d=1\), 則有
進一步有
由歸納假設可知結論成立.
Corollary 1: 假設\(W \in \mathbb{R}_+^{d \times d}\) , 則\(\mathcal{A}(W)\)能夠表示有向無環圖當且僅當
\(\mathrm{tr}(e^{W \circ W}) =d\)
注: \(\circ\) 表示哈達瑪積, 即對應元素相乘.
上面依然要求\(W\)各元素大於0, 一個好的辦法是:
Theorem 1: 一個矩陣\(W \in \mathbb{R}^{d \times d}\), 則\(\mathcal{A}(W)\) 能表示有向無環圖當且僅當
proof:
\(\mathcal{A}(W)=\mathcal{A}(W \circ W)\).
\(\mathrm{tr}(I + W \circ W)^d =d\)
Theorem 2: 一個矩陣\(W \in \mathbb{R}^{d \times d}\), 則\(\mathcal{A}(W)\) 能表示有向無環圖當且僅當
注: \(W \circ W\)前面加個系數也是沒關系的.
性質的推導
故, 此時我們只需設置
顯然滿足1,2,3, 接下來我們推導其梯度
故
注: 其中\(M =W \circ W\).
求解
利用augmented Lagrangian轉換為(這一塊不是很懂, 但只是數值求解的東西, 不影響理解)
具體求解算法如下: