Pytorch中的Embedding


有兩個Embedding函數,通常是用前面這一個

ref https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

torch.nn.Embedding( num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)

  • num_embeddings: size of the directionary of embedding,也就是詞匯表的大小=不同單詞的個數
  • embedding_dim: the size of each embedding vector,也就是embedding向量的維度

一個簡單的lookup table(查找表),用例存儲固定了dictionary 和 size 的embeddings

它將所有的embedding(詞向量)都存起來了,可以通過詞的索引檢索它們

輸入是一個索引列表,輸出是詞向量列表

>>> # an Embedding module containing 10 tensors of size 3
>>> # 有10種不同的單詞,每個單詞表示為一個3維的向量
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

 embedding = nn.Embedding(10, 3),這里embedding是一個表,input是在表中的索引 

 

另一個函數是

ref https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html

torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

  • input: Tensor containing indices into the embedding matrix,即在詞向量矩陣中的索引列表
  • weight: embedding matrix,即詞向量矩陣,行數為最大可能的索引數+1,列數為詞向量的維度
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.tensor([[1,2,4,5],[4,3,2,9]])
>>> # an embedding matrix containing 10 tensors of size 3
>>> embedding_matrix = torch.rand(10, 3)
>>> F.embedding(input, embedding_matrix)
tensor([[[ 0.8490,  0.9625,  0.6753],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.6246,  0.9751,  0.3618],
         [ 0.4161,  0.2419,  0.7383]],

        [[ 0.6246,  0.9751,  0.3618],
         [ 0.0237,  0.7794,  0.0528],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.3385,  0.8612,  0.1867]]])

這個embedding_matrix不用訓練的嗎?直接用隨機數??

 

而事實上,在nn.Embedding的內部實現中,也是調用的F.embedding

 

 

可見embedding matrix也是隨機數,因為torch.Tensor(a, b)是正態分布的隨機數


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM