【python學習筆記】pytorch中的nn.Embedding用法


本篇博客參考文章:
通俗講解pytorch中nn.Embedding原理及使用

embedding

詞嵌入,通俗來講就是將文字轉換為一串數字。因為數字是計算機更容易識別的一種表達形式。

我們詞嵌入的過程,就相當於是我們在給計算機制造出一本字典的過程。計算機可以通過這個字典來間接地識別文字。

詞嵌入向量的意思也可以理解成:詞在神經網絡中的向量表示。

詳細可看嵌入(embedding)層的理解

pytorch中的embedding

輸入是一個索引列表,輸出是相應的詞嵌入

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
                   max_norm=None,  norm_type=2.0,   scale_grad_by_freq=False, 
                   sparse=False,  _weight=None)

參數

  • num_embeddings(int):詞典的大小尺寸,比如總共出現5000個詞,那就輸入5000。此時index為(0-4999)
  • embedding_dim(int):嵌入向量的維度,即用多少維來表示一個符號。
  • padding_idx(int,optional):填充id,比如,輸入長度為100,但是每次的句子長度並不一樣,后面就需要用統一的數字填充,而這里就是指定這個數字,這樣,網絡在遇到填充id時,就不會計算其與其它符號的相關性。(初始化為0)
  • max_norm(float,optional):最大范數,如果嵌入向量的范數超過了這個界限,就要進行再歸一化。
  • norm_type(float,optional):指定利用什么范數計算,並用於對比max_norm,默認為2范數。
  • scale_grad_by_freq(boolean ,可選):根據單詞在mini-batch中出現的頻率,對梯度進行放縮。默認為False.
  • sparse(bool,可選):若為True,則與權重矩陣相關的梯度轉變為稀疏張量,默認為False。

使用例子

import torch
import numpy as np

#建立詞向量層,詞數為13,嵌入向量維數設為3
embed = torch.nn.Embedding(13,3)
#句子對['I am a boy.','How are you?','I am very lucky.']
#batch = [['i','am','a','boy','.'],['i','am','very','lucky','.'],['how','are','you','?']]

#將batch中的單詞詞典化,用index表示每個詞(先按照這幾個此創建詞典)
#batch = [[2,3,4,5,6],[2,3,7,8,6],[9,10,11,12]]

#每個句子實際長度
#lens = [5,5,4]

#加上EOS標志且index=0
#batch = [[2,3,4,5,6,0],[2,3,7,8,6,0],[9,10,11,12,0]]

#每個句子實際長度(末端加上EOS)
lens = [6,6,5]

#PAD過后,PAD標識的index=1
batch = [[2,3,4,5,6,0],[2,3,7,8,6,0],[9,10,11,12,0,1]]

#RNN的每一步要輸入每個樣例的一個單詞,一次輸入batch_size個樣例
#所以batch要按list外層是時間步數(即序列長度),list內層是batch_size排列。
#即[seq_len,batch_size]
batch = np.transpose(batch)

batch=torch.LongTensor(batch)

embed_batch = embed(batch)

print(embed_batch)
tensor([[[ 0.4582,  0.1676,  0.4495],
         [ 0.4582,  0.1676,  0.4495],
         [ 0.0691, -0.4414, -1.1965]],

        [[-1.0109,  0.7178,  0.0478],
         [-1.0109,  0.7178,  0.0478],
         [ 1.0389, -0.1143,  0.9865]],

        [[ 0.4041,  0.8421, -1.1829],
         [-0.6804,  1.7318,  0.4238],
         [-0.3201, -0.5068,  0.0071]],

        [[ 0.1110, -0.0441, -0.3261],
         [-0.1142, -2.5226,  0.6788],
         [ 0.2379,  1.5004,  0.4553]],

        [[ 1.8359, -1.2531,  1.2757],
         [ 1.8359, -1.2531,  1.2757],
         [-1.4669,  0.1150, -0.7636]],

        [[-1.4669,  0.1150, -0.7636],
         [-1.4669,  0.1150, -0.7636],
         [-1.9697,  0.3393,  0.0089]]], grad_fn=<EmbeddingBackward>)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM