Pointer Networks


 

 

 原文鏈接:https://arxiv.org/abs/1506.03134

 

Motivation

現有的序列化預測通常使用RNN。RNN的問題在於輸出數量固定,對於答案長度動態變化的問題並不適用。

作者以凸包問題(Convex Hull)為例。給定一定數量的點,希望找到一系列點組成凸多邊形,使得任一點或者在多邊形內部,或者為多邊形頂點。

 

 

 

顯然對於固定總數的點,其凸包頂點數量可變,采用seq2seq的序列模型難以實現比較好的效果。

 

 

 

因此,作者提出了指針網絡(Pointer Networks),以解決答案長度可變的問題,通過生成由輸入到輸出的答案指針,實現答案從輸入的拷貝。 

 

Review:seq2seq Model

RNN序列模型如圖所示。輸入一系列序列,在模型前半部分進行編碼,后半部分進行解碼。解碼部分的每一步依據模型參數、輸入序列、此前所有步的輸出生成當前步的輸出。

在給定的輸入序列P和RNN模型參數θ的情況下,模型預測一系列答案CP(C1 ~ Cm(P))的概率可以用以下的條件概率乘積表示:

 

 其中模型參數θ可以通過最大化正確答案概率(對數和)來調整:

 

 

Review:Attention

注意力機制如圖所示。當前步的輸出通過對每一個輸入(經過encode編碼)加權求和產生。

 

 

對於第i步的輸出,首先對於每一步輸入編碼計算權重uji,將該步編碼ej與第i步解碼di分別進行線性變換並相加,通過tanh激活並乘以參數v,最后得到uji。

將uji經過softmax得到aji,作為最終權重。以aji作為權重對此前每一步的編碼ej進行加權平均,得到第i步的輸出。

 

 

Pointer Networks

作者提出指針網絡,通過將輸入直接拷貝作為輸出,實現對可變長度答案的預測。

首先使用與Attention同樣的方法,得到每一步的權重。

 

 然后直接將權重最大的輸入項作為輸出,即將softmax后的權重(Attention中的aji作為預測為第j個輸入的概率。

 

直觀來講,就是將Attention中的權重作為指針,將輸入指向輸出,生成針對輸入序列的概率分布。 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM