seq2seq聊天模型(二)——Scheduled Sampling


使用典型seq2seq模型,得到的結果欠佳,怎么解決

結果欠佳原因在這里

  • 在訓練階段的decoder,是將目標樣本["吃","蘭州","拉面"]作為輸入下一個預測分詞的輸入。
  • 而在預測階段的decoder,是將上一個預測結果,作為下一個預測值的輸入。(注意查看預測多的箭頭)
    這個差異導致了問題的產生,訓練和預測的情景不同。
    在預測的時候,如果上一個詞語預測錯誤,還后面全部都會跟着錯誤,蝴蝶效應。

解決辦法-Scheduled Sampling

修改訓練時decoder的模型
基礎模型只會使用真實lable數據作為輸入, 現在,train-decoder不再一直都是真實的lable數據作為下一個時刻的輸入。
train-decoder時以一個概率P選擇模型自身的輸出作為下一個預測的輸入,以1-p選擇真實標記作為下一個預測的輸入。
Secheduled sampling(計划采樣),即采樣率P在訓練的過程中是變化的。
一開始訓練不充分,先讓P小一些,盡量使用真實的label作為輸入,隨着訓練的進行,將P增大,多采用自身的輸出作為下一個預測的輸入。
隨着訓練的進行,P越來越大大,train-decoder模型最終變來和inference-decoder預測模型一樣,消除了train-decoder與inference-decoder之間的差異

總之:
通過這個scheduled-samping方案,抹平了訓練decoder和預測decoder之間的差異!讓預測結果和訓練時的結果一樣。

tensorflow

tensoflow已經完成了這個模型,直接調用,設定參數可以使用


training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length + 2,
                    embedding=self.dec_Wemb,
                    sampling_probability=self.sampling_probability,
                    time_major=False,
                    name='training_helper')
                    
                    
self.sampling_probability = tf.placeholder(
                tf.float32,
                shape=[],
                name='sampling_probability')     
 
# 下面這個時feed_dic
# 隨着epoch的增大,sampling_probability_list逐漸變為1,即全部采用自身輸出作為下個輸入, 
sampling_probability_list = np.linspace(
        start=0.0,
        stop=1.0,
        num=n_epoch,
        dtype=np.float32)
        
                    

實際結果

效果很好


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM