Factorization Machine Model
如果僅考慮兩個樣本間的交互, 則factorization machine的公式為:
$\hat{y}(\mathbf{x}):=w_0 + \sum_{i=1}^nw_ix_i + \sum_{i=1}^n\sum_{j=i+1}^n<\mathbf{v}_i, \mathbf{v}_j>x_ix_j$
其中的參數為
$w_0 \in \mathcal{R}, \mathbf{w}\in\mathbb{R}^n,\mathbf{V}\in\mathbb{R}^{n\times k}\tag{1}$
$\mathbf{v_i}$是樣本$i$的向量表示, 維度為$k$, 兩個向量的點積越大, 表示這兩個樣本越相似.
2路FM(2-way FM)捕獲了樣本自身以及樣本之間的交互, 詳解如下
$w_0$是全局偏置
$w_i$是第$i$個樣本的強度
$\hat{w}_{i,j}:=<\mathbf{v}_i, \mathbf{v}_j>$代表第$i$個樣本和第$j$個樣本的交互. 與其為每個樣本對都設置一個參數$w_{i,j}$, FM模型將其分解成兩個向量之間的乘積.
通常來說, 對於任一正定矩陣$\mathbf{W}$, 只要$k$充分大, 都可以找到一個矩陣$\mathbf{V}$使得 $\mathbf{W}= \mathbf{V} \cdot \mathbf{V}^t$. 然而如果數據比較稀疏, 因為數據量不夠估計復雜的交互矩陣$\mathbf{W}$, 通常需要選擇小一點的$k$. 而FM把這種交互分解后, 會學習的更好, 因為FM通過分解來打破了交互之間的依賴性, 減少了參數. 下圖是一個用於預測用戶對電影打分的數據集:
易知$(1)$式的計算復雜度為$\mathit{O}(kn^2)$, 但是其可以做如下化簡:
$\sum_{i=1}^n\sum_{j=i+1}^n<\mathbf{v}_i, \mathbf{v}_j>x_ix_j$
$=\frac{1}{2}\sum_{i=1}^n\sum_{j=1}^n<\mathbf{v}_i,\mathbf{v}_j>x_ix_j - \frac{1}{2}\sum_{i=1^n}<\mathbf{v}_i, \mathbf{v}_j>x_ix_j$
$=\frac{1}{2}\left(\sum_{i=1}^n\sum_{j=1}^n\sum_{f=1}^kv_{i, f}v_{j, f}x_ix_j - \sum_{i=1}^n\sum_{f=1}^kv_{i,f}v_{i,f}x_ix_i\right)$
$=\frac{1}{2}\sum_{f=1}^k\left(\left(\sum_{i=1}^nv_{i, f}x_i\right)\left(\sum_{j=1}^nv_{j,f}x_j\right) - \sum_{i=1}^nv_{i, f}^2x_i^2\right)$
$=\frac{1}{2}\sum_{f=1}^k\left(\left(\sum_{i=1}^nv_{i, f}x_i\right)^2 -\sum_{i=1}^nv_{i, f}^2x_i^2\right)$
根據上述化簡, $(1)$式的計算復雜度可以變為$\mathit{O}(kn)$
FM可以用作回歸, 二分類以及排序. 為了防止過擬合, 最好添加$\mathcal{L}_2$正則化項.
- 回歸 直接使用MSE作為Loss
- 二分類 使用hinge loss或者logit loss.
- 排序 對樣本對$(\mathbf{x}^{(a)}, \mathbf{x}^{(b)})$進行優化, 使用pairwise的分類loss
模型學習
FM的參數$(w_o, \mathbf{w}, \mathbf{V})$可以通過梯度下降方法來學習, 比如SGD.
$\frac{\partial}{\partial \theta}=\begin{cases} 1 & if \hspace{2 pt}\theta \hspace{2 pt}is \hspace{2 pt}w_0 \\ x_i, & if \hspace{2 pt}\theta \hspace{2 pt}is \hspace{2 pt}w_i \\ x_i\sum_{j=1}^nv_{j, f}x_j - v_{i, f}x_i^2, & if \hspace{2 pt}\theta \hspace{2 pt}is\hspace{2 pt} v_{i, f}\end{cases}$
其中$\sum_{j=1}^nv_{j, f}x_j$獨立於$i$, 可以提前計算. 所以所有的梯度都可以在$\mathit{O}(1)$時間內計算得到, 而每個樣本的參數更新可以在$\mathit{O}(kn)$內完成.
2路FM可以擴展到k路:
$\hat{y}(x):=w_0 + \sum_{i=1}^nw_ix_i + \sum_{l=2}^d\sum_{i_1=1}^n\dots\sum_{i_l=i_{l-1}+1}^n\left(\prod_{j=1}^lx_{i_{j}}\right) \left(\sum_{f=1}^{k_l}\prod_{j=1}^lv_{i_j, f}^{(l)}\right)$
參考文獻:
[1]. Factorization Machines. Steffen Rendle. ICDM 2010.