torch.nn.Embedding


PyTorch快速入門教程七(RNN做自然語言處理) - pytorch中文網
原文出處: https://ptorch.com/news/11.html

在pytorch里面實現word embedding是通過一個函數來實現的:nn.Embedding

# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

word_to_ix = {'hello': 0, 'world': 1}
embeds = nn.Embedding(2, 5)
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)
print(hello_embed)

 這就是我們輸出的hello這個詞的word embedding,代碼會輸出如下內容,接下來我們解析一下代碼:

Variable containing:
 0.4606  0.6847 -1.9592  0.9434  0.2316
[torch.FloatTensor of size 1x5]

 

首先我們需要word_to_ix = {'hello': 0, 'world': 1},每個單詞我們需要用一個數字去表示他,這樣我們需要hello的時候,就用0來表示它。

接着就是word embedding的定義nn.Embedding(2, 5),這里的2表示有2個詞,5表示5維度,其實也就是一個2x5的矩陣,所以如果你有1000個詞,每個詞希望是100維,你就可以這樣建立一個word embeddingnn.Embedding(1000, 100)。如何訪問每一個詞的詞向量是下面兩行的代碼,注意這里的詞向量的建立只是初始的詞向量,並沒有經過任何修改優化,我們需要建立神經網絡通過learning的辦法修改word embedding里面的參數使得word embedding每一個詞向量能夠表示每一個不同的詞。

hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)

 

接着這兩行代碼表示得到一個Variable,它的值是hello這個詞的index,也就是0。這里要特別注意一下我們需要Variable,因為我們需要訪問nn.Embedding里面定義的元素,並且word embeding算是神經網絡里面的參數,所以我們需要定義Variable

hello_embed = embeds(hello_idx)這一行表示得到word embedding里面關於hello這個詞的初始詞向量,最后我們就可以print出來。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM