nn.Embedding()函數理解


torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
 max_norm=None,  norm_type=2.0,   scale_grad_by_freq=False, 
 sparse=False,  _weight=None)

參數解釋:

  • num_embeddings (python:int) – 詞典的大小尺寸,比如總共出現5000個詞,那就輸入5000。此時index為(0-4999)
  • embedding_dim (python:int) – 嵌入向量的維度,即用多少維來表示一個符號。
  • padding_idx (python:int, optional) – 填充id,比如,輸入長度為100,但是每次的句子長度並不一樣,后面就需要用統一的數字填充,而這里就是指定這個數字,這樣,網絡在遇到填充id時,就不會計算其與其它符號的相關性。(初始化為0)
  • max_norm (python:float, optional) – 最大范數,如果嵌入向量的范數超過了這個界限,就要進行再歸一化。
  • norm_type (python:float, optional) – 指定利用什么范數計算,並用於對比max_norm,默認為2范數。
  • scale_grad_by_freq (boolean, optional) – 根據單詞在mini-batch中出現的頻率,對梯度進行放縮。默認為False

nn.Embedding()就是隨機初始化了一個num_embeddings*embedding_dim的二維表,每一行代表着對應索引的詞向量表示。加入我們要得到一句話的初始化詞向量,我們需要將

句子進行分詞,得到每一個詞的索引,將索引送入nn.embedding()函數中,會自動在建立的二維表中找到索引對應的初始化詞向量。

 

nn.Embedding()的輸入是(batch_size,seq_len) 輸出是(batch_size,seq_len,embedding)

這個函數實質上是將索引轉成one-hot向量,之后再與權重矩陣W相乘進行運算,再反向傳播過程中,不斷更新權重W,使得詞向量能更准確的表示這個詞。

參考連接:https://www.jianshu.com/p/63e7acc5e890


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM