Siamese Network簡介
Siamese Network 是一種神經網絡的框架,而不是具體的某種網絡,就像seq2seq一樣,具體實現上可以使用RNN也可以使用CNN。
簡單的說,Siamese Network用於評估兩個輸入樣本的相似度。網絡的框架如下圖所示
Siamese Network有兩個結構相同,且共享權值的子網絡。分別接收兩個輸入X1X1與X2X2,將其轉換為向量Gw(X1)Gw(X1)與Gw(X2)Gw(X2),再通過某種距離度量的方式計算兩個輸出向量的距離EwEw。
訓練Siamese Network采用的訓練樣本是一個tuple (X1,X2,y)(X1,X2,y),標簽y=0y=0表示X1X1與X2X2屬於不同類型(不相似、不重復、根據應用場景而定)。y=1y=1則表示X2X2與X2X2屬於相同類型(相似)。
LOSS函數的設計應該是
1. 當兩個輸入樣本不相似(y=0y=0)時,距離EwEw越大,損失越小,即關於EwEw的單調遞減函數。
2. 當兩個輸入樣本相似(y=1y=1)時,距離EwEw越大,損失越大,即關於EwEw的單調遞增函數。
用L+(X1,X2)L+(X1,X2)表示y=1y=1時的LOSS, L−(X1,X2)L−(X1,X2)表示y=0y=0時的LOSS,則LOSS函數可以寫成如下形式
Siamese Network的基本架構、輸入、輸出以及LOSS函數的設計原則如上文所述,接下來就說一下在NLP的場景,具體的Siamese Network應該如何設計。
LSTM Siamese Network
在文本方面,需要計算兩個文本之間的相似度,或者僅僅判斷是否相似,是否重復的場景也很多。簡單直接的方法可以直接從字面上判斷,使用BOW模型,使用SimHash算法都行。但是有些場景,字面上看可能不相似,但是從語義上看是相似的,這就需要更復雜的模型來捕捉它的語義信息了。
比如Quora就有這方面的需求,問答類型的網站希望同樣的問題只有一個就好,但表述問題的方式可以多種多樣,因此需要能夠捕捉到更多語義上的信息。
將Siamese Network架構中的用於表征X1X1與X2X2的Network更換為LSTM網絡,就可以用於判斷兩個輸入文本是否語義上相似。
Learning Text Similarity with Siamese Recurrent Networks這篇文章介紹了這種網絡的結構,也給出了具體的參數。網絡的結構如下圖所示
這是論文中的截圖,在文本輸入與BILSTM之間還有一個embedding層。
論文中的LSTM Siamese Network用了4層hidden unit size為64的BILSTM,再將每一時刻的輸出取平均作為輸入XX的表征向量,后面再接dim=128的全連接層,得到的兩個向量f(X1)f(X1)與f(X2)f(X2)對應的就是第一部分介紹Siamese Network基本框架中的Gw(X1)Gw(X1)與Gw(X2)Gw(X2)。
這里的相似度EE使用的是余弦相似度,即
所以−1≤E(X1,X2)≤1−1≤E(X1,X2)≤1,與歐氏距離不一樣的是,EcosEcos的值越大,代表距離越近,值越小距離越遠,所以LOSS函數的設計也要與上文所說的相反。即
y=0y=0時,LOSS函數隨着EE單調遞增
y=1y=1時,LOSS函數隨着EE單調遞減
具體的有
總的LOSS函數不變 。mm是設定的閾值,可視化LOSS函數如下
LSTM Siamese Network總結起來就是
1. 將Siamese Network中的Encoder換成BILSTM
2. 將距離的計算改成余弦距離
3. 修改相應的LOSS函數
這個設計上還是有一些可以改進的,比如在BILSTM輸出后,加一個attention,而不是直接average每個時刻的輸出,這樣可以更好的表征輸入的文本。
代碼實現
目前github上有一個開源實現,deep-siamese-text-similarity,但是代碼稍微有點亂,並且有些地方實現的不對。
比如BILSTM模型的定義中
outputs, _, _ = tf.nn.bidirectional_rnn(lstm_fw_cell_m, lstm_bw_cell_m, x, dtype=tf.float32) return outputs[-1]
- 1
- 2
- 3
將最后一個時刻的輸出作為表征向量,這樣就忽略了其它時刻的輸出。
還有定義兩個孿生網絡的時候,使用了不同的權值,根據Siamese Network的設計,在這里應該是要reuse_variable來共享權值的。
自己實現了一個,也放到github上:https://github.com/THTBSE/siamese-lstm-network。