深度學習與Pytorch入門實戰(十六)情感分類實戰(基於IMDB數據集)


筆記摘抄

提前安裝torchtext和scapy,運行下面語句(壓縮包地址鏈接:https://pan.baidu.com/s/1_syic9B-SXKQvkvHlEf78w 提取碼:ahh3):

pip install torchtext

pip install scapy

pip install 你的地址\en_core_web_md-2.2.5.tar.gz  
  • 在torchtext中使用spacy時,由於field的默認屬性是tokenizer_language='en'

  • 當使用 en_core_web_md 時要改 field.py文件中 創建的field屬性為tokenizer_language='en_core_web_md',且data.Field()中的參數也要改為tokenizer_language='en_core_web_md'

1. 加載數據

分類任務中,我們所需要接觸到的數據有文本字符串和兩種情感,"pos"或者"neg"。

  • Field的參數制定了數據會被怎樣處理。

  • 我們使用TEXT field來定義如何處理電影評論,使用LABEL field來處理兩個情感類別。

  • 我們的TEXT field帶有tokenize='spacy',這表示我們會用spaCy tokenizer來tokenize英文句子。如果我們不特別聲明tokenize這個參數,那么默認的分詞方法是使用空格。

  • 安裝spaCy

pip install -U spacy
python -m spacy download en

1.1 分割訓練集測試集

import numpy as np
import torch
from torch import nn, optim
from torchtext import data, datasets

# 為CPU設置隨機種子
torch.manual_seed(123)

# 兩個Field對象定義字段的處理方法(文本字段、標簽字段)
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_md')  # 分詞
LABEL = data.LabelField(dtype=torch.float)
  • TorchText支持很多常見的自然語言處理數據集。

  • 下面的代碼會自動下載IMDb數據集,然后分成train/test兩個torchtext.datasets類別。數據被前面的Fields處理。IMDb數據集一共有50000電影評論,每個評論都被標注為正面的或負面的。

# from torchtext import data, datasets
# IMDB共50000影評,包含正面和負面兩個類別。數據被前面的Field處理
# 按照(TEXT, LABEL) 分割成 訓練集,測試集
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

print('len of train data:', len(train_data))        # 25000
print('len of test data:', len(test_data))          # 25000

# torchtext.data.Example : 用來表示一個樣本,數據+標簽
print(train_data.examples[15].text)                 # 文本:句子的單詞列表
print(train_data.examples[15].label)                # 標簽: 積極
len of train data: 25000
len of test data: 25000
['Like', 'one', 'of', 'the', 'previous', 'commenters', 'said', ',', 'this', 'had', 'the', 'foundations', 'of', 'a', 'great', 'movie', 'but', 'something', 'happened', 'on', 'the', 'way', 'to', 'delivery', '.', 'Such', 'a', 'waste', 'because', 'Collette', "'s", 'performance', 'was', 'eerie', 'and', 'Williams', 'was', 'believable', '.', 'I', 'just', 'kept', 'waiting', 'for', 'it', 'to', 'get', 'better', '.', 'I', 'do', "n't", 'think', 'it', 'was', 'bad', 'editing', 'or', 'needed', 'another', 'director', ',', 'it', 'could', 'have', 'just', 'been', 'the', 'film', '.', 'It', 'came', 'across', 'as', 'a', 'Canadian', 'movie', ',', 'something', 'like', 'the', 'first', 'few', 'seasons', 'of', 'X', '-', 'Files', '.', 'Not', 'cheap', ',', 'just', 'hokey', '.', 'Also', ',', 'it', 'needed', 'a', 'little', 'more', 'suspense', '.', 'Something', 'that', 'makes', 'you', 'jump', 'off', 'your', 'seat', '.', 'The', 'movie', 'reached', 'that', 'moment', 'then', 'faded', 'away', ';', 'kind', 'of', 'like', 'a', 'false', 'climax', '.', 'I', 'can', 'see', 'how', 'being', 'too', 'suspenseful', 'would', 'have', 'taken', 'away', 'from', 'the', '"', 'reality', '"', 'of', 'the', 'story', 'but', 'I', 'thought', 'that', 'part', 'was', 'reached', 'when', 'Gabriel', 'was', 'in', 'the', 'hospital', 'looking', 'for', 'the', 'boy', '.', 'This', 'movie', 'needs', 'to', 'have', 'a', 'Director', "'s", 'cut', 'that', 'tries', 'to', 'fix', 'these', 'problems', '.']
pos
  • 由於我們現在只有train/test這兩個分類,所以我們需要創建一個新的validation set。我們可以使用.split()創建新的分類。

  • 默認的數據分割是 70、30,如果我們聲明split_ratio,可以改變split之間的比例,split_ratio=0.8表示80%的數據是訓練集,20%是驗證集。

  • 我們還聲明random_state這個參數,確保我們每次分割的數據集都是一樣的。

import random
SEED = 1234
train_data, valid_data = train_data.split(random_state=random.seed(SEED))

檢查一下現在每個部分有多少條數據。

print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')
Number of training examples: 17500
Number of validation examples: 7500
Number of testing examples: 25000

1.2 創建vocabulary

  • vocabulary把每個單詞一一映射到一個數字。

  • 使用10k個單詞來構建單詞表(用max_size這個參數可以設定)

  • 所有其他的單詞都用<unk>來表示。

  • 詞典中應當有10002個單詞,且有兩個label,可以通過TEXT.vocabTEXT.label查詢,可以直接用stoi(stringtoint) 或者 itos(inttostring) 來查看單詞表。

TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d') # unk_init=torch.Tensor.normal_
LABEL.build_vocab(train_data)

print(len(TEXT.vocab))             # 10002
print(TEXT.vocab.itos[:12])        # ['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
print(TEXT.vocab.stoi['and'])      # 5
print(LABEL.vocab.stoi)            # defaultdict(None, {'neg': 0, 'pos': 1})
['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
5
defaultdict(<function _default_unk_index at 0x7f44ffd3ed90>, {'neg': 0, 'pos': 1})
  • 當我們把句子傳進模型的時候,我們是按照一個個 batch 穿進去的。

  • 也就是說,我們一次傳入了好幾個句子,而且每個batch中的句子必須是相同的長度。為了確保句子的長度相同,TorchText會把短的句子pad到和最長的句子等長。

  • 下面我們來看看訓練數據集中最常見的單詞。

print(TEXT.vocab.freqs.most_common(20))
[('the', 201455), (',', 192552), ('.', 164402), ('a', 108963), ('and', 108649), ('of', 100010), ('to', 92873), ('is', 76046), ('in', 60904), ('I', 54486), ('it', 53405), ('that', 49155), ('"', 43890), ("'s", 43151), ('this', 42454), ('-', 36769), ('/><br', 35511), ('was', 34990), ('as', 30324), ('with', 29691)]

1.3 創建iteratiors

  • 每個itartion都會返回一個batch的examples。

  • 每個iterator中各有兩部分:詞(.text)和標簽(.label),其中 text 全部轉換成數字了

  • BucketIterator會把長度差不多的句子放到同一個batch中,確保每個batch中不出現太多的padding。

  • 這里因為pad比較少,所以把 也當做了模型的輸入進行訓練。

  • 如果有GPU,還可以指定每個iteration返回的tensor 都在GPU上。

batchsz = 30

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
                                (train_data, valid_data, test_data),
                                batch_size = batchsz,
                                device = device,
                                repeat = False
                               )
# for i, _ in enumerate(train_iterator):
#     print(i)
batch = next(iter(train_iterator))
print(batch.text)
print(batch.text.shape)
print(batch.label.shape)
tensor([[  25,   66, 1215,  ...,  471,   11, 1267],
        [ 132,    9, 2348,  ...,   42,  465,  298],
        [  19,    6, 1703,  ...,    3,  142, 1678],
        ...,
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1]])
torch.Size([1058, 64])   # 一個batch,64條數據,1058個參數
torch.Size([64])
print(TEXT.pad_token)
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
print(PAD_IDX)
mask = batch.text == PAD_IDX
print(mask)
<pad>
1
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])

2. 定義模型

class RNN(nn.Module):

  def __init__(self, vocab_size, embedding_dim, hidden_dim):
    super(RNN, self).__init__()

    # [0-10001] => [100]
    # 參數1:embedding個數(單詞數), 參數2:embedding的維度(詞向量維度)
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    # [100] => [256]
    # 雙向LSTM,所以下面FC層使用 hidden_dim*2
    self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
                       bidirectional=True, dropout=0.5) 
    # [256*2] => [1]
    self.fc = nn.Linear(hidden_dim*2, 1)
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    """
    x: [seq_len, b] vs [b, 3, 28, 28]
    """
    # [seq_len, b, 1] => [seq_len, b, 100]
    embedding = self.dropout(self.embedding(x))

    # output: [seq, b, hid_dim*2]
    # hidden/h: [num_layers*2, b, hid_dim]
    # cell/c: [num_layers*2, b, hid_dim]
    output, (hidden, cell) = self.rnn(embedding)
    # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
    # 雙向,所以要把最后兩個輸出連接
    hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
    # [b, hid_dim*2] => [b, 1]
    hidden = self.dropout(hidden)
    out = self.fc(hidden)

    return out
  • 使用 預訓練過的embedding 來替換隨機初始化

  • Tip:.copy_() 這種 帶着下划線的函數 均代表 替換inplace

rnn = RNN(len(TEXT.vocab), 100, 256)                          #詞個數,詞嵌入維度,輸出維度

pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)    # torch.Size([10002, 100])

# 使用預訓練過的embedding來替換隨機初始化
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.

3. 訓練模型

  • 首先定義模型和損失函數。
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)

# BCEWithLogitsLoss是針對二分類的CrossEntropy
criteon = nn.BCEWithLogitsLoss()

如果使用GPU加速,改成:

# 優化函數
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)

# BCEWithLogitsLoss是針對二分類的CrossEntropy
criteon = nn.BCEWithLogitsLoss().to(device)

rnn = rnn.to(device)
RNN(
  (embedding): Embedding(10002, 100)
  (rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=512, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
  • 定義一個函數用於計算准確率
def binary_acc(preds, y):

    preds = torch.round(torch.sigmoid(preds))
    correct = torch.eq(preds, y).float()
    acc = correct.sum() / len(correct)
    return acc
  • 定義一個訓練函數
def train(rnn, iterator, optimizer, criteon):
    epoch_loss = 0
    epoch_acc = 0
    avg_acc = []
    rnn.train()   # 表示進入訓練模式

    for i, batch in enumerate(iterator):
        # [seq, b] => [b, 1] => [b]
        # batch.text 就是上面forward函數的參數text,壓縮維度是為了和batch.label維度一致
        pred = rnn(batch.text).squeeze(1)

        loss = criteon(pred, batch.label)
        # 計算每個batch的准確率
        acc = binary_acc(pred, batch.label).item()
        avg_acc.append(acc)

        optimizer.zero_grad()  # 清零梯度准備計算
        loss.backward()        # 反向傳播
        optimizer.step()       # 更新訓練參數

        if i % 10 == 0:
            print(i, acc)
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    avg_acc = np.array(avg_acc).mean()
    print('avg acc:', avg_acc)
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator)   

4. 評估模型

  • 定義一個評估函數,和訓練函數高度重合

  • 區別是要把rnn.train()改為rnn.val(),不需要反向傳播過程。

def evaluate(rnn, iterator, criteon):
    avg_acc = []
    epoch_loss = 0
    epoch_acc = 0
    rnn.eval()         # 表示進入測試模式

    with torch.no_grad():
        for batch in iterator:
            pred = rnn(batch.text).squeeze(1)      # [b, 1] => [b]
            loss = criteon(pred, batch.label)
            acc = binary_acc(pred, batch.label).item()
            avg_acc.append(acc)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    avg_acc = np.array(avg_acc).mean()
    print('test acc:', avg_acc)

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

5. 運行


best_valid_loss = float('inf')
for epoch in range(10):
    # 訓練模型
    train_loss, train_acc = train(rnn, train_iterator, optimizer, criteon)
    # 評估模型
    valid_loss, valid_acc = evaluate(rnn, valid_iterator, criteon)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'wordavg-model.pt')
    
view result
0 0.8666667342185974
10 0.9666666984558105
20 0.8000000715255737
30 0.8666667342185974
40 0.8666667342185974
50 0.8000000715255737
60 0.9333333969116211
70 0.7666667103767395
80 0.9000000357627869
90 0.8666667342185974
100 0.9000000357627869
110 0.7666667103767395
120 0.8000000715255737
130 0.9666666984558105
140 0.8666667342185974
150 0.9000000357627869
160 0.9000000357627869
170 0.9000000357627869
180 0.8000000715255737
190 0.8000000715255737
200 0.9333333969116211
210 0.9000000357627869
220 0.9333333969116211
230 0.8666667342185974
240 0.9000000357627869
250 0.7666667103767395
260 0.9333333969116211
270 0.9000000357627869
280 0.8000000715255737
290 0.8666667342185974
300 0.9333333969116211
310 0.7666667103767395
320 0.9000000357627869
330 0.9666666984558105
340 0.9666666984558105
350 0.8333333730697632
360 0.9000000357627869
370 0.8000000715255737
380 0.9000000357627869
390 0.8666667342185974
400 0.8333333730697632
410 0.9000000357627869
420 0.9333333969116211
430 0.8333333730697632
440 0.8666667342185974
450 0.8000000715255737
460 0.9333333969116211
470 0.8666667342185974
480 0.9333333969116211
490 0.9333333969116211
500 0.9000000357627869
510 0.8333333730697632
520 0.8666667342185974
530 0.9333333969116211
540 0.9333333969116211
550 0.7666667103767395
560 0.8333333730697632
570 0.9333333969116211
580 0.9000000357627869
590 0.9333333969116211
600 0.9000000357627869
610 0.8333333730697632
620 0.7333333492279053
630 0.8333333730697632
640 0.8333333730697632
650 0.9000000357627869
660 0.9333333969116211
670 0.8000000715255737
680 0.9000000357627869
690 0.9000000357627869
700 0.9000000357627869
710 0.9333333969116211
720 0.8000000715255737
730 0.9333333969116211
740 0.9666666984558105
750 0.9666666984558105
760 0.9333333969116211
770 0.8666667342185974
780 0.8666667342185974
790 0.8666667342185974
800 0.9666666984558105
810 0.9000000357627869
820 0.9000000357627869
830 0.9333333969116211
avg acc: 0.8855715916454078
test acc: 0.8775779855051201
0 0.9000000357627869
10 0.9666666984558105
20 0.9000000357627869
30 0.9000000357627869
40 0.9666666984558105
50 0.9666666984558105
60 0.7666667103767395
70 0.8666667342185974
80 0.9333333969116211
90 0.9000000357627869
100 0.9333333969116211
110 0.8666667342185974
120 0.9000000357627869
130 0.9000000357627869
140 0.8666667342185974
150 0.8333333730697632
160 0.8333333730697632
170 0.9333333969116211
180 0.8333333730697632
190 0.9000000357627869
200 0.8666667342185974
210 1.0
220 1.0
230 0.9666666984558105
240 0.9000000357627869
250 0.8000000715255737
260 0.9333333969116211
270 0.9666666984558105
280 0.9333333969116211
290 0.9666666984558105
300 0.9000000357627869
310 0.9333333969116211
320 0.9333333969116211
330 0.9666666984558105
340 0.9666666984558105
350 0.9666666984558105
360 0.9333333969116211
370 0.9666666984558105
380 0.8333333730697632
390 0.7333333492279053
400 0.9000000357627869
410 0.9000000357627869
420 0.8000000715255737
430 0.9333333969116211
440 0.8666667342185974
450 0.9333333969116211
460 0.8333333730697632
470 0.9333333969116211
480 0.9333333969116211
490 0.8000000715255737
500 0.9666666984558105
510 0.9000000357627869
520 1.0
530 0.9666666984558105
540 1.0
550 0.9333333969116211
560 0.9000000357627869
570 1.0
580 0.9000000357627869
590 0.9000000357627869
600 0.8666667342185974
610 0.8333333730697632
620 0.9000000357627869
630 0.9000000357627869
640 0.8666667342185974
650 0.9000000357627869
660 0.9666666984558105
670 0.9333333969116211
680 0.8666667342185974
690 0.9000000357627869
700 0.8666667342185974
710 0.9333333969116211
720 0.9666666984558105
730 0.9666666984558105
740 0.9666666984558105
750 0.9000000357627869
760 0.9000000357627869
770 0.9000000357627869
780 0.9333333969116211
790 0.9333333969116211
800 0.9333333969116211
810 0.8666667342185974
820 0.9000000357627869
830 0.9000000357627869
avg acc: 0.9071942910873633
test acc: 0.8886890964542361
0 0.9333333969116211
10 0.9333333969116211
20 0.9666666984558105
30 0.9333333969116211
40 0.9333333969116211
50 0.8666667342185974
60 1.0
70 0.8333333730697632
80 0.9666666984558105
90 0.9000000357627869
100 0.9666666984558105
110 0.9666666984558105
120 0.9333333969116211
130 0.9333333969116211
140 0.9000000357627869
150 0.9666666984558105
160 0.8666667342185974
170 0.9666666984558105
180 0.9666666984558105
190 0.9333333969116211
200 0.9333333969116211
210 0.8666667342185974
220 0.9000000357627869
230 0.8333333730697632
240 0.9333333969116211
250 0.8000000715255737
260 0.8666667342185974
270 0.9000000357627869
280 0.9000000357627869
290 0.9666666984558105
300 0.9333333969116211
310 0.9000000357627869
320 0.9333333969116211
330 0.9666666984558105
340 0.9000000357627869
350 1.0
360 0.9666666984558105
370 0.9333333969116211
380 0.9333333969116211
390 0.9666666984558105
400 0.9666666984558105
410 0.9666666984558105
420 1.0
430 0.9000000357627869
440 1.0
450 0.9000000357627869
460 0.9333333969116211
470 1.0
480 0.9000000357627869
490 0.9333333969116211
500 0.9000000357627869
510 0.9000000357627869
520 0.9333333969116211
530 0.9333333969116211
540 0.9666666984558105
550 0.9666666984558105
560 0.9666666984558105
570 0.9666666984558105
580 0.8333333730697632
590 0.9666666984558105
600 0.9333333969116211
610 0.9333333969116211
620 0.9333333969116211
630 1.0
640 0.9000000357627869
650 0.8666667342185974
660 0.9333333969116211
670 0.8666667342185974
680 0.9666666984558105
690 0.9333333969116211
700 1.0
710 0.9666666984558105
720 0.9666666984558105
730 0.9000000357627869
740 0.9333333969116211
750 0.9666666984558105
760 1.0
770 0.8666667342185974
780 0.9000000357627869
790 0.9333333969116211
800 0.9666666984558105
810 0.9000000357627869
820 0.9666666984558105
830 0.8000000715255737
avg acc: 0.9266587171337302
test acc: 0.8872902161068768
0 0.9333333969116211
10 1.0
20 1.0
30 0.9666666984558105
40 0.9666666984558105
50 1.0
60 0.9333333969116211
70 0.9666666984558105
80 0.8666667342185974
90 0.9666666984558105
100 0.9333333969116211
110 0.8666667342185974
120 0.9333333969116211
130 0.9000000357627869
140 0.8333333730697632
150 0.9666666984558105
160 0.9666666984558105
170 0.8666667342185974
180 0.9666666984558105
190 0.9666666984558105
200 0.9333333969116211
210 0.9333333969116211
220 0.9666666984558105
230 0.9666666984558105
240 0.9000000357627869
250 1.0
260 0.9333333969116211
270 0.9666666984558105
280 0.9333333969116211
290 0.9000000357627869
300 1.0
310 0.9333333969116211
320 0.9666666984558105
330 0.9666666984558105
340 0.9333333969116211
350 0.9333333969116211
360 0.9333333969116211
370 0.9333333969116211
380 1.0
390 1.0
400 0.9333333969116211
410 1.0
420 0.9333333969116211
430 0.9666666984558105
440 0.9333333969116211
450 0.9333333969116211
460 0.9666666984558105
470 0.8333333730697632
480 1.0
490 0.9333333969116211
500 0.9666666984558105
510 0.9000000357627869
520 0.9000000357627869
530 1.0
540 0.9333333969116211
550 0.9666666984558105
560 0.9000000357627869
570 0.9333333969116211
580 0.9333333969116211
590 0.9666666984558105
600 0.8333333730697632
610 0.9333333969116211
620 0.8666667342185974
630 0.9000000357627869
640 0.9333333969116211
650 0.9666666984558105
660 0.9666666984558105
670 0.9333333969116211
680 0.9333333969116211
690 0.9333333969116211
700 0.9666666984558105
710 0.9000000357627869
720 0.9333333969116211
730 1.0
740 0.9666666984558105
750 0.9333333969116211
760 0.9666666984558105
770 0.8333333730697632
780 0.9666666984558105
790 0.9000000357627869
800 0.9000000357627869
810 0.9000000357627869
820 0.9666666984558105
830 0.9666666984558105
avg acc: 0.9356515197445163
test acc: 0.890008042184569
0 1.0
10 1.0
20 0.9000000357627869
30 0.8666667342185974
40 0.9000000357627869
50 0.9333333969116211
60 0.9000000357627869
70 0.9666666984558105
80 0.8666667342185974
90 0.9000000357627869
100 0.9333333969116211
110 1.0
120 0.9666666984558105
130 0.9666666984558105
140 1.0
150 0.9333333969116211
160 0.9333333969116211
170 0.9333333969116211
180 1.0
190 0.9666666984558105
200 0.9333333969116211
210 1.0
220 0.9666666984558105
230 1.0
240 0.9333333969116211
250 0.8333333730697632
260 0.9666666984558105
270 0.9333333969116211
280 0.9000000357627869
290 1.0
300 0.9666666984558105
310 0.9333333969116211
320 0.9000000357627869
330 0.9000000357627869
340 1.0
350 0.9666666984558105
360 1.0
370 0.9666666984558105
380 0.9000000357627869
390 0.9666666984558105
400 0.9666666984558105
410 0.9333333969116211
420 0.9000000357627869
430 1.0
440 0.9333333969116211
450 0.9666666984558105
460 0.9666666984558105
470 1.0
480 1.0
490 0.9666666984558105
500 1.0
510 1.0
520 1.0
530 1.0
540 0.8666667342185974
550 1.0
560 0.9333333969116211
570 0.9333333969116211
580 0.9666666984558105
590 0.9666666984558105
600 0.9333333969116211
610 0.9000000357627869
620 0.9333333969116211
630 0.9666666984558105
640 0.9666666984558105
650 0.9333333969116211
660 0.9333333969116211
670 0.9000000357627869
680 0.9333333969116211
690 0.9000000357627869
700 0.9333333969116211
710 0.9666666984558105
720 0.9666666984558105
730 0.9333333969116211
740 0.9333333969116211
750 1.0
760 0.9666666984558105
770 0.9333333969116211
780 0.9333333969116211
790 0.9000000357627869
800 1.0
810 0.9000000357627869
820 1.0
830 0.9000000357627869
avg acc: 0.9450040338136595
test acc: 0.8848521674422624
0 1.0
10 1.0
20 0.9666666984558105
30 0.9666666984558105
40 1.0
50 1.0
60 0.9666666984558105
70 1.0
80 0.9666666984558105
100 0.9666666984558105
110 0.9666666984558105
120 0.9333333969116211
130 0.9666666984558105
140 0.9666666984558105
150 1.0
160 0.9666666984558105
170 1.0
180 1.0
190 0.9666666984558105
200 0.8666667342185974
210 1.0
220 0.8666667342185974
230 0.9666666984558105
240 0.9333333969116211
250 0.8333333730697632
260 0.9666666984558105
270 0.9666666984558105
280 0.9000000357627869
290 0.9666666984558105
300 0.9666666984558105
310 0.9333333969116211
320 1.0
330 0.9666666984558105
340 0.9666666984558105
350 0.9333333969116211
360 0.9000000357627869
370 0.8666667342185974
380 0.9333333969116211
390 0.8333333730697632
400 0.9666666984558105
410 1.0
420 0.9666666984558105
430 0.9666666984558105
440 1.0
450 0.9666666984558105
460 0.9333333969116211
470 1.0
480 0.9666666984558105
490 1.0
500 0.9666666984558105
510 0.9333333969116211
520 0.8666667342185974
530 0.9666666984558105
540 1.0
550 1.0
560 0.9333333969116211
570 0.9333333969116211
580 1.0
590 0.9666666984558105
600 0.9666666984558105
610 0.9666666984558105
620 0.9666666984558105
630 0.9666666984558105
640 0.9333333969116211
650 0.9000000357627869
660 0.9333333969116211
670 1.0
680 0.9333333969116211
690 0.9666666984558105
700 0.9333333969116211
710 1.0
720 0.9333333969116211
730 1.0
740 0.9666666984558105
750 0.9666666984558105
760 0.8666667342185974
770 0.9000000357627869
780 0.8000000715255737
790 0.9666666984558105
800 0.9666666984558105
810 0.8666667342185974
820 1.0
830 0.9666666984558105
avg acc: 0.9509592677334802
test acc: 0.8718625588668621
0 1.0
10 1.0
20 0.9666666984558105
30 0.9333333969116211
40 1.0
50 0.9666666984558105
60 0.9666666984558105
70 0.9666666984558105
80 1.0
90 0.9333333969116211
100 1.0
110 0.9666666984558105
120 0.9666666984558105
130 0.9666666984558105
140 0.9666666984558105
150 0.9666666984558105
160 0.9666666984558105
170 1.0
180 0.9666666984558105
190 0.9000000357627869
200 1.0
210 1.0
220 0.9333333969116211
230 1.0
240 0.9666666984558105
250 1.0
260 0.9666666984558105
270 0.9666666984558105
280 0.9333333969116211
290 0.9333333969116211
300 0.9666666984558105
310 0.9666666984558105
320 0.9666666984558105
330 0.9333333969116211
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 0.9333333969116211
380 0.9666666984558105
390 0.9333333969116211
400 0.9666666984558105
410 0.9666666984558105
420 0.9666666984558105
430 0.9333333969116211
440 0.9333333969116211
450 0.9666666984558105
460 1.0
470 1.0
480 0.9666666984558105
490 0.9333333969116211
500 0.9666666984558105
510 0.9333333969116211
520 0.9666666984558105
530 0.9666666984558105
540 1.0
550 0.9666666984558105
560 0.9333333969116211
570 1.0
580 0.9666666984558105
590 0.9666666984558105
600 1.0
610 0.9000000357627869
620 0.9333333969116211
630 0.9333333969116211
640 0.9333333969116211
650 0.9666666984558105
660 0.9000000357627869
670 0.9000000357627869
680 1.0
690 0.9333333969116211
700 0.9666666984558105
710 0.8000000715255737
720 0.9333333969116211
730 0.8666667342185974
740 0.9333333969116211
750 0.9666666984558105
760 1.0
770 0.9333333969116211
780 0.9000000357627869
790 0.9666666984558105
800 0.9333333969116211
810 0.8666667342185974
820 0.9000000357627869
830 0.9666666984558105
avg acc: 0.9605116213111283
test acc: 0.8822142779827118
0 0.9666666984558105
10 0.9666666984558105
20 1.0
30 0.9666666984558105
40 1.0
50 0.9666666984558105
60 1.0
70 0.9000000357627869
80 1.0
90 0.9666666984558105
100 0.9333333969116211
110 1.0
120 1.0
130 0.9666666984558105
140 0.9666666984558105
150 1.0
160 0.9666666984558105
170 0.9333333969116211
180 0.9666666984558105
190 0.9333333969116211
200 0.9666666984558105
210 1.0
220 0.9666666984558105
230 1.0
240 0.9666666984558105
250 1.0
260 0.9333333969116211
270 0.9666666984558105
280 0.9000000357627869
290 1.0
300 0.9333333969116211
310 0.9666666984558105
320 0.9666666984558105
330 0.9333333969116211
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 1.0
380 1.0
390 0.9000000357627869
400 1.0
410 1.0
420 1.0
430 1.0
440 1.0
450 0.9666666984558105
460 0.9000000357627869
470 1.0
480 1.0
490 0.8666667342185974
500 1.0
510 1.0
520 1.0
530 0.9666666984558105
540 0.9000000357627869
550 1.0
560 0.9333333969116211
570 0.9666666984558105
580 1.0
590 0.9666666984558105
600 0.9333333969116211
610 0.9666666984558105
620 0.9666666984558105
630 1.0
640 0.9000000357627869
650 0.9666666984558105
660 1.0
670 0.9000000357627869
680 0.9333333969116211
690 1.0
700 1.0
710 1.0
720 0.9666666984558105
730 1.0
740 1.0
750 1.0
760 1.0
770 0.8666667342185974
780 0.9666666984558105
790 0.9333333969116211
800 0.9666666984558105
810 1.0
820 1.0
830 0.9666666984558105
avg acc: 0.9653077817363419
test acc: 0.8769784666222634
0 1.0
10 0.9666666984558105
20 1.0
30 0.9333333969116211
40 1.0
50 1.0
60 1.0
70 1.0
80 0.9666666984558105
90 0.9333333969116211
100 0.9666666984558105
110 0.9666666984558105
120 1.0
130 0.9666666984558105
140 1.0
150 1.0
160 0.9666666984558105
170 1.0
180 0.9333333969116211
190 1.0
200 0.9666666984558105
210 1.0
220 0.8333333730697632
230 1.0
240 1.0
250 0.9666666984558105
260 0.9666666984558105
270 0.9000000357627869
280 0.9666666984558105
290 0.9333333969116211
300 0.9666666984558105
310 0.9666666984558105
320 0.9333333969116211
330 1.0
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 0.9666666984558105
380 0.9666666984558105
390 1.0
400 0.9333333969116211
410 0.9333333969116211
420 1.0
430 0.9666666984558105
440 0.9666666984558105
450 0.9333333969116211
460 1.0
470 0.9666666984558105
480 1.0
490 1.0
500 0.9333333969116211
510 0.9666666984558105
520 1.0
530 0.9333333969116211
540 0.9666666984558105
550 0.9333333969116211
560 0.9333333969116211
570 0.9333333969116211
580 1.0
590 1.0
600 0.9333333969116211
610 0.9666666984558105
620 1.0
630 1.0
640 1.0
650 0.9666666984558105
660 1.0
670 1.0
680 1.0
690 0.9333333969116211
700 1.0
710 0.9333333969116211
720 1.0
730 1.0
740 0.9666666984558105
750 0.9000000357627869
760 0.9000000357627869
770 0.9333333969116211
780 0.9666666984558105
790 1.0
800 1.0
810 0.9666666984558105
820 0.9666666984558105
830 1.0
avg acc: 0.9697442299885144
test acc: 0.8815348212667506
0 0.9666666984558105
10 0.9666666984558105
20 0.9666666984558105
30 0.9666666984558105
40 0.9666666984558105
50 0.8666667342185974
60 1.0
70 1.0
80 1.0
90 1.0
100 1.0
110 0.9666666984558105
120 1.0
130 1.0
140 1.0
150 0.9666666984558105
160 1.0
170 0.9333333969116211
180 1.0
190 0.9000000357627869
200 1.0
210 0.8666667342185974
220 1.0
230 1.0
240 1.0
250 0.9000000357627869
260 1.0
270 1.0
280 0.9666666984558105
290 0.9666666984558105
300 0.9666666984558105
310 0.9666666984558105
320 0.9666666984558105
330 1.0
340 1.0
350 0.9333333969116211
360 0.9666666984558105
370 1.0
380 0.9666666984558105
390 1.0
400 0.9666666984558105
410 1.0
420 1.0
430 1.0
440 1.0
450 1.0
460 1.0
470 1.0
480 0.9666666984558105
490 1.0
500 1.0
510 1.0
520 0.9666666984558105
530 0.9666666984558105
540 0.9666666984558105
550 0.9000000357627869
560 0.9000000357627869
570 0.9666666984558105
580 1.0
590 0.9666666984558105
600 1.0
610 0.9666666984558105
620 1.0
630 0.9666666984558105
640 0.9666666984558105
650 1.0
660 0.9666666984558105
670 1.0
680 0.9666666984558105
690 0.9666666984558105
700 0.9666666984558105
710 1.0
720 0.8666667342185974
730 1.0
740 0.9666666984558105
750 0.9333333969116211
760 1.0
770 0.9666666984558105
780 1.0
790 0.9666666984558105
800 1.0
810 1.0
820 1.0
830 0.9333333969116211
avg acc: 0.9726618941453435
test acc: 0.8754996503714463

6. 預測

  • 輸出的預測:是('pos':1, 'neg':0)字符串的編號
for batch in test_iterator:
    # batch_size個預測
    preds = rnn(batch.text).squeeze(1)
    preds = predice_test(preds)
    # print(preds)

    i = 0
    for text in batch.text:
        # 遍歷一句話里的每個單詞
        for word in text:
            print(TEXT.vocab.itos[word], end=' ')
    
        print('')
        # 輸出3句話
        if i == 3:
            break
        i = i + 1

    i = 0
    for pred in preds:
        idx = int(pred.item())
        print(idx, LABEL.vocab.itos[idx])
        # 輸出3個結果(標簽)
        if i == 3:
            break
        i = i + 1
    break  
Anyone <unk> Great A If <unk> Without The Brilliant This <unk> This If This Ten Absolutely For A This One Add a Just This I More What Brilliant Read <unk> 
who Classic story great you hires a <unk> . movie it is you is minutes fantastic pure touching is of this mesmerizing love is hope suspenseful a and the <unk> 
gives Waters , film 've a doubt mixed <unk> is with the like quite of ! <unk> movie a the little film the a this , script moving book <unk> 
this ! great in ever psychopath , with along terrible all greatest <unk> possibly people Whatever vampire . good funniest gem that interplay great group more , performances , interpretation 
1 pos
1 pos
1 pos
1 pos

可以寫成

import spacy
nlp = spacy.load('en_core_web_md')

def predict_sentiment(sentence):
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
#     print(tokenized)   # ['This', 'film', 'is', 'terrible']
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    tensor = torch.LongTensor(indexed).to(device)
    text = tensor.unsqueeze(0)  
    prediction = torch.sigmoid(rnn(tensor))
    return prediction.item()
predict_sentiment("This film is great")  # 1.0

導入模型,並測試

model.load_state_dict(torch.load('wordavg-model.pt'))
test_loss, test_acc = evaluate(rnn, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM