Fairseq-快速可擴展的序列建模工具包


一種快速、可擴展的序列建模工具包,Pytorch的高級封裝庫,適用於機器翻譯、語言模型和篇章總結等建模任務。

  • 抽象

  • 注冊

  • 實現上的特點

抽象

Dataset:數據加載

Fairseq中的Dataset基本都是按功能逐層封裝,按需組合起來。所有數據加載的實現均位於fairseq/data下面。

兩個比較常用的數據處理類:

  • IndexedDataset直接處理/讀取,bin/raw文件。

  • LanguagepairDataset包含src和tgt兩個Dataset,用於處理成對的數據。比如在機器翻譯的中翻英任務中,處理中文和英文文本。

Option:參數定義

Fairseq中的參數統一使用argsparser庫實現,模型通用參數被定義在fairseq/option.py下。同時每個模型均有其特有參數,通過每個模型的add_args(parser)函數定義。

  • fairseq/option.py中定義了6類通用參數,對應的函數分別是get_preprocessing_parser(),get_training_parser(),get_generation_parser(),get_interactive_generation_parser(),get_eval_lm_parser()get_validation_parser。這6類通用參數又通過add_***_args()組裝起來。

  • 在模型各自的實現中,通過繼承接口中的add_args()添加模型特有的參數,比如fairseq\models\lstm.py中通過add_args()添加了LSTM模型的encoder-embed-dim,encoder-layers,encoder-bidirectional等參數。

Model:網絡模型的抽象

Fairseq中的Model負責模型的定義,包括各個模型的總體結構,每個模型提供argeparser供用戶傳入自定義參數。所有的模型定義均位於fairseq/models下。

所有的模型均繼承自類BaseFairseqModel,而BaseFairseqModel又繼承自torch.nn.Module,因此所有的Fairseq模型均可以作為其它Pytorch代碼的模塊。模型的具體結構,比如嵌入層的維度、隱藏層的個數由architectures定義。特別地,

  1. LanguageModelEncoderDecoderModel均直接繼承自BaseFairseqModelBaseFairseqModel主要提供add_args()build_model()等統一的接口,以及模型加載等功能。

  2. EncoderDecoder均直接繼承自torch.nn.ModuleLanguageModelEncoderDecoder包含EncoderDecoder

  3. Decoder包含一個output_layer抽象接口,BERT這樣的語言模型由於存在輸出,因此繼承的是Decoder

FairseqTask

Fairseq主要以FairseqTask為核心,使用FairseqTask將各個部分銜接起來。一個Task可以是TranslationTask(比如使用Transformer做翻譯),也可以是一個LanguageModelTask。所有的任務定義均位於fairseq/tasks下。一個FairseqTask實例需要實現以下功能:

  1. 字典存儲/加載。

  2. 提供加載、分切數據的幫助類,獲得裝載數據的Dataloaderiterator等。

  3. 創建模型。

  4. 創建criterion

  5. 循環訓練、驗證,直至收斂或達到指定訓練輪數。

FairseqTask實現的功能基本上包含了模型運行的全部要素,可以看到主函數的調用流程:

Criterion

所有的准則(criterion)定義在fairseq/criterions內,准則對給定的模型和小批量數據計算損失(Loss)。也就是:

\[loss=criterion(model,batch) \]

在Fairseq中,實現所謂的“混合專家”(mixture-of-experts)模型,准則(criterion)實現EM風格(EM-style)的訓練,以節約算力。

Optimizer

所有的優化器(optimizer)定義在fairseq/optim中,優化器根據梯度,更新模型參數。

Scheduler

定義在fairseq/lr_scheduler中。在訓練過程中,調整學習率。

注冊

注冊機制

Fairseq中許多組件都是公共的,模塊之間盡量解耦,需要一種方式指定應該跑哪一個Model,數據裝載使用哪一個Dataset。注冊機制在Fairseq中大量使用。

以FairseqTask的注冊機制為例,FairseqTask包含了多個子類,如TranslationTaskMaskedLMTaskLanguageModeling等。在fairseq/task/__init__.py中會通過for循環import該目錄下的所有文件,最后在TASK_REGISTRY中可以得到key:cls形式的模塊存儲器。其中,key為字符串,cls為模塊的cls對象。這種方式可以很方便的通過指定參數,導入想要的模塊。在函數裝飾器setup_task()register_task()中,通過TASK_REGISTRY載入和注冊task。

舉個例子,通過裝飾器進行注冊,比如:

@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
    ...

ModelCriterion部分都有該機制的身影。

在主函數train.py中,通過setup_task(),build_model(),build_criterion()中得到所需部分。

同樣地,可以使用注冊機制固化模型參數。一些模型僅僅有模型參數上的區別,本質並無區別,比如roberta_base,roberta_large。因此需要指定各個模型的具體默認參數,當然這些參數,用戶可以通過fairseq的參數系統進行指定。這些模型的具體參數同樣可以用注冊的方式固定下來,在使用時可以更加方便。

  1. 對於模型,使用@register_model裝飾器注冊。
@register_model('roberta')
class RobertaModel(FairLanguageModel):
    ...
  1. 對於具體的模型結構,使用@register_model_architecture裝飾器注冊。
@register_model_architecture('robtera','roberta_large')
def roberta_large_architecture(args):
    args.encoder_layers = getattr(args,'encoder_layers',24)
    args.encoder_embed_dim = getattr(args,'encoder_embed_dim',1024)
    ...
    base_architecture(args)

注冊的函數對象會在ARCH_CONFIG_REGISTERY中存儲,並在option.py中調用:

ARCH_CONFIG_REGISTRY[args.arch](args)

實現上的特點

Fairseq使用Pytorch實現,支持多機、多卡、混合精度訓練。提升速度,降低顯存占用。

分批次

Fairseq依據序列長度對源/目標序列進行分組,相似長度的序列作為一組,以減小對序列的補齊填充操作。每一個mini-batch內的樣本在訓練過程中不變,但每一輪訓練時都會打亂mini-batch間的順序。當在多卡、多機上運行時,每一個worker的mini-batches平均長度有所不同,以實現更有代表性(more representative)的迭代。

多GPU訓練

  • 使用NCCL2庫和torch.distributed作為GPU間的通信。

  • 每個GPU上保留一個模型副本。

  • 前向計算和反向傳播異步。Fairseq中每一層的梯度計算完成后,都會把結果存放到緩存中,當緩存大小達到某一個閾值之后,在一個后台線程中同步梯度,反向傳播照常進行。在每一個GPU上累加梯度,以減小worker上處理時間的方差,這樣就不必等待計算比較慢的worker。

如圖所示,圖a在同步梯度時,等待最慢的worker,因此產生了大量的等待時間(白色所示,idle)。但Fairseq同時采用了圖b和圖c的技術,反向傳播(back-propagation)和梯度同步(gradient synchronization)同時進行,並且累加梯度以減少worker上面處理時間的“抖動”,從而提升訓練速度。

混合精度訓練

Fairseq同時支持半精度浮點(half precision float point, FP16)和全精度浮點(full precision float point, FP32)的訓練和推斷。在前后向以及worker之間規約(all-reduce)時,使用FP16。但在參數更新時仍然采用FP32,以保證計算精度。由於FP16提供的精度有限,為了防止激活和梯度的下溢出,Fairseq實現了所謂的動態損失縮放(dynamic loss scaling)。當FP16的梯度在worker之間同步完成之后,將縮放到FP16的數字恢復為原來的FP32,並更新模型權重。

推斷優化

Fairseq通過增量解碼(incremental decoding)提供了更快的推理速度。所謂的增量解碼,就是在解碼時,將之前tokens處於激活beam狀態下的模型狀態(model states)緩存起來,以備后用,這樣每一個新的token進來,只需要計算新的狀態即可。也就是說,如果使用FairseqDecoder接口實現普通的解碼器,對於每一個輸出,都需要重新整個解碼器隱狀態,計算復雜度O(n^2)。而使用FairseqIncrementalDecoder接口實現增量解碼,就可以實現O(n)的解碼速度。

在訓練和推理階段,通過用戶指定的最大tokens數量,構建動態樣本數量的batch。並且Fairseq在保證准確率的前提下,支持FP16精度的推斷。相比於FP32,FP16推斷將解碼速度提高54%。注意:在Fairseq中,用戶通過指定max-tokens,Fairseq會自動構建不定數量的batch送入模型訓練。當然,用戶同樣可以通過batch-size指定一個批次中的最大樣本數。

Fairseq repo (Python): https://github.com/pytorch/fairseq
Paper: http://cn.arxiv.org/abs/1904.01038
Document: fairseq.readthedocs.io
https://zhuanlan.zhihu.com/p/100249351
https://zhuanlan.zhihu.com/p/100643955


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM