Torchtext指南 (側重於NMT)


Torchtext指南 (側重於NMT)

torchtext是一個對於NLP來說非常棒的預處理數據的工具。

本文記錄一下自己學習的過程,側重於NMT。

一個基本的操作流程:

  • 創建Field,定義通用的文本處理操作:
from torchtext import data, datasets
SRC = data.Field(...)
TRG = data.Field(...)
  • 加載你的數據集
train_data, valid_data, test_data = datasets.TranslationDataset.splits(...)
  • 創建詞匯表
SRC.build_vocab(train_data.src)
TRG.build_vocab(train_data.trg)
  • 最后生成迭代器進行Batch操作
train_iter = data.BucketIterator(...)
valid_iter = data.BucketIterator(...)

Field

貌似有好幾種,對於我自己來說常用的就是:

torchtext.data.Field(sequential=True, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.int64, preprocessing=None, postprocessing=None, lower=False, tokenize=None, tokenizer_language='en', include_lengths=False, batch_first=False, pad_token='<pad>', unk_token='<unk>', pad_first=False, truncate_first=False, stop_words=None, is_target=False)

參數具體詳解:

  • sequential: 是否把數據表示成序列,如果是False, 不能使用分詞 默認值: True.

  • use_vocab: 是否使用詞典對象. 如果是False 數據的類型必須已經是數值類型. 默認值: True.

  • init_token: 每一條數據的起始字符 默認值: None.

  • eos_token: EOS 默認值: None.

  • fix_length: 修改每條數據的長度為該值,不夠的用pad_token補全. 默認值: None.

  • tensor_type: 把數據轉換成的tensor類型 默認值: torch.LongTensor.

  • preprocessing:在分詞之后和數值化之前使用的管道 默認值: None.

  • postprocessing: 數值化之后和轉化成tensor之前使用的管道默認值: None.

  • lower: 是否把數據轉化為小寫 默認值: False.

  • tokenize: 分詞函數. 默認值: str.split.

  • include_lengths: 是否返回一個已經補全的最小batch的元組和和一個包含每條數據長度的列表 . 默認值: False.

  • batch_first: 是否Batch first. 默認值: False.

  • pad_token: PAD 默認值: " ".

  • unk_token: UNK 默認值: " ".

  • pad_first: 是否補全第一個字符. 默認值: False.

datasets(這里只講TranslationDataset)

torchtext.datasets.TranslationDataset(path, exts, fields, **kwargs)
  • path: 兩種語言的數據文件的路徑的公共前綴

  • exts: 包含每種語言路徑擴展名的tuple

  • fields: 包含將用於每種語言的Field的tuple

  • **kwargs: 等等

torchtext.datasets.TranslationDataset.splits(path, exts, fields, **kwargs)

機器翻譯的話train, dev, test一般是分開的,這時候就要用splits啦。

classmethod splits(path=None, root='.data', train=None, validation=None, test=None, **kwargs)
  • path: 兩種語言的數據文件的路徑的公共前綴

  • train: train路徑

  • valiadation: dev路徑

  • test: test路徑

Iterator

torchtext.data.Iterator

class torchtext.data.Iterator(dataset, batch_size, sort_key=None, device=None, batch_size_fn=None, train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
  • dataset: 前面定義的dataset
  • batch_size: batch大小
  • batch_size_fn: 用來產生動態batch
  • sort_key: 排序的key w
  • train: 是不是train data
  • repeat: 在不在不同的epoch中重復
  • shuffle: 大打不打亂數據
  • sort: 是否排序
  • sort_within_batch: batch內部是否排序
  • device: cpu or gpu

BucketIterator

class torchtext.data.BucketIterator(dataset, batch_size, sort_key=None, device=None, batch_size_fn=None, train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)

定義一個迭代器,將類似長度的示例一起批處理。

減少在每一個epoch中shuffled batches需要padding的量

NMT常用寫法

from torchtext import data, datasets

UNK_TOKEN = "<unk>"
PAD_TOKEN = "<pad>"    
SOS_TOKEN = "<s>"
EOS_TOKEN = "</s>"

LOWER = True
MAX_LEN = 30
MIN_FREQ = 10

DEVICE=torch.device('cuda:0')

#Step 1
#一般我用的是已經分好詞的數據, 所以tokenize=None
#如果用bpe的話,共享vocab,那就只用一個field就好啦
SRC = data.Field(tokenize=None, 
                 batch_first=True, lower=LOWER, include_lengths=True,
                 unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
                 init_token=None, eos_token=EOS_TOKEN)
                     
TRG = data.Field(tokenize=None, 
                 batch_first=True, lower=LOWER, include_lengths=True,
                 unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
                 init_token=SOS_TOKEN, eos_token=EOS_TOKEN)


#Step 2
train_data, valid_data, test_data = datasets.TranslationDataset.splits(
path='./WMT2014_en-de/', train= 'xxx',validation='yyy', test='zzz', exts=('.de', '.en'), fields=(SRC, TRG),
filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN)

#Step 3
SRC.build_vocab(train_data.src, min_freq=MIN_FREQ)
TRG.build_vocab(train_data.trg, min_freq=MIN_FREQ)

#Step 4

train_iter = data.BucketIterator(train_data, batch_size=64, train=True, 
                                 sort_within_batch=True, 
                                 sort_key=lambda x: (len(x.src), len(x.trg)), repeat=False,
                                 device=DEVICE)

# 如果用一些未排序的外部文件進行valid,經常有問題。
#為了方便,將batch大小設置為1
valid_iter = data.BucketIterator(valid_data, batch_size=64, train=False, 
                                 sort_within_batch=True, 
                                 sort_key=lambda x: (len(x.src), len(x.trg)), repeat=False,
                                 device=DEVICE)

相關介紹:

  1. https://zhuanlan.zhihu.com/p/31139113?edition=yidianzixun&utm_source=yidianzixun&yidian_docid=0IIxNSe7

  2. https://blog.csdn.net/u012436149/article/details/79310176


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM