知識蒸餾入門、實踐和相關庫的介紹及使用


本文已經過時,請前往: https://www.cnblogs.com/infgrad/p/13767918.html 查看知識蒸餾的最新文章

1 前言

知識蒸餾,其目的是為了讓小模型學到大模型的知識,通俗說,讓student模型的輸出接近(擬合)teacher模型的輸出。所以知識蒸餾的重點在於擬合二字,即我們要定義一個方法去衡量student模型和teacher模型接近程度,說白了就是損失函數。

為什么我們需要知識蒸餾?因為大模型推理慢難以應用到工業界。小模型直接進行訓練又不如蒸餾得到的效果好。

下面介紹四個比較熱門的蒸餾文章,這四個本人均有實踐,這也是這篇文章的干貨所在。

2 知識蒸餾的開山之作

Hinton 在論文: Distilling the Knowledge in a Neural Network 提出了知識蒸餾的方法。網上關於這方面的資料實在是太多了,我就簡單總結下吧。
損失函數:$$Loss = aL_{soft} + (1-a)L_{hard}$$
其中\(L_{soft}\)是StudentModel和TeacherModel的輸出的交叉熵,\(L_{hard}\)是StudentModel輸出和真實標簽的交叉熵。
這里提個問題,為什么需要\(L_{soft}\)?這可是知識蒸餾的靈魂問題,知道的可以在評論里講一下,嘿嘿。
再細說一下\(L_{soft}\)。我們知道TeacherModel的輸出是經過Softmax處理的,指數e拉大了各個類別之間的差距,最終輸出結果特別像一個one-hot向量,這樣不利於StudentModel的學習,因此我們希望輸出更加一些。因此我們需要改一下softmax函數:

顯然T越大輸出越軟。這樣改完之后,對比原始softmax,梯度相當於乘了\(1/T^2\),因此\(L_{soft}\)需要再乘以\(T^2\)來與\(L_{hard}\)在一個數量級上。

最后再放一張這個算法的整體框架圖(圖片來自https://blog.csdn.net/nature553863/article/details/80568658)以便大家理解:

3 TinyBert

3.1 基本思路介紹

首先說到對Bert的蒸餾大家肯定會想到就用微調好的Bert作為TeacherModel去訓練一個StudentModel,沒錯目前就是這么干的。那么下面的問題就是我們選取什么模型作為StudentModel,這個已經有一些嘗試了,比如有人使用BiLSTM,但是更多的人還是繼續使用了Bert,只不過這個Bert會比原始的Bert小。在TinyBert中,StudentModel使用的是減少了embedding size、hidden size和num hidden layers的小bert。
所以到此我們就解決了StudentModel的選取問題。

那么新的問題又來了,我們怎么初始化StudentModel?最直接的解決方案就是隨機化,模型不都是這么訓練的嗎?但是這種效果就真的好嗎?我看未必,如果沒問題那么為啥那么多人用預訓練模型?所以啊,我們需要一個比較好StudentModel的參數,確切說,我們需要一個預訓練的StudentModel,那么怎么獲取一個預訓練的StudentModel,TinyBert給出的答案就是咱們再用預訓練好的Bert蒸餾出一個預訓練好的StudentModel。

Ok,TinyBert基本講完了,我來簡單總結下,TinyBert一共分為兩步:

  1. 用pretrained bert蒸餾一個pretrained TinyBert
  2. 用fine-tuned bert蒸餾一個fine-tuned TinyBert( 它的初始化參數就是第一步里pretrained TinyBert)

3.2 損失函數

OK,TinyBert基本思路講完了。下面說一說最重要的一點:TinyBert具體是怎么蒸餾的,即損失函數。

直接看公式吧:

解釋下這個公式:

  • \(m\):整數,0到StudentModel層數之間
  • \(S_m\):StudentModel第m層的輸出
  • \(g(m)\):映射函數,實際意義是讓StudentModel的第m層學習TeacherModel第g(m)層的輸出
  • \(T_{g(m)}\):TeacherModel的第g(m)層的輸出
  • \(M\):StudentModel隱層數量,那么StudentModel第M+1層就是預測層的輸出了(logits)
  • \(L_{embd}(S_0,T_0)\):word embedding層的損失函數,用的是MSE
  • \(L_{hidden}和L_{attn}\):hidden層和attention層的損失函數,都是MSE
  • \(L_{pred}\):預測層損失函數,用的交叉熵,其他文獻也有用KL-Distance的,其實反向傳播的時候都一樣。

再補充一句:在進行蒸餾的時候,會先進行隱層蒸餾(即m<=M),然后再執行m=M+1時的蒸餾。
總結一下,有助於大家理解:TinyBERT在蒸餾的時候,不僅要讓StudentModel學到最后一層的輸出,還要學到中間幾層的輸出。換言之,StudengModel的某一隱層可以學到TeacherModel若干隱層的輸出。感覺蒸餾的粒度比較細,我覺得可以叫做LayerBasedDistillation。

3.3 實戰經驗

  1. 在硬件和數據有限的條件下,我們很難做預訓練模型的蒸餾,但是可以借鑒TinyBERT的思路,直接做TaskSpecific的蒸餾,至於如何初始化模型,我有兩個建議:要不直接拿原始Teacher模型初始化,要不找一個別人預訓練好的小模型進行初始化。我直接用的RBT3模型初始化,效果很好。
  2. 蒸餾完StudentModel,一定要測StudentModel的泛化能力,切記要測試StudentModel的泛化能力。不要害怕,StudentModel泛化能力也不差,但是你要測下。。。。
  3. 靈活一些,蒸餾學習目前沒有一個統一的方法,有很多地方可以自己改一改試一試。

4 DistilBert

4.1 基本思路

說完了TinyBert,想再和大家聊一聊DistilBert,DistilBert要比TinyBert簡單不少,我就少用些文字,DistilBert使用預訓練好的Bert作為TeacherModel訓練了一個StudentModel,這里的StudentModel就是層數少的Bert,注意這里得到的DistilBERT本質上還是一個預訓練模型,因此用到具體下游任務上時,還是需要用專門的數據去微調,這里就是純粹的微調,不需要考慮再用蒸餾學習輔助。HuggingFace已經提供了若干蒸餾好的預訓練模型,大家直接拿過來當Bert用就行了,但是好些沒有中文的。。。

4.2 損失函數

DIstillBERT的損失函數:Lce + Lmlm + Lcos。

  • Lce,StudentModel和TeacherModellogits的交叉熵
  • Lmlm,StudentModel中遮擋語言模型的損失函數
  • Lcos,StudentModel和TeacherModel hidden states的余弦損失函數,注意在TinyBERT里用的是MSE,而且還考慮了attention的MSE。

5 BERT-of-Theseus

這個准確的來說不是知識蒸餾,但是它確實減小了模型體積,而且思路和TinyBERT、DistillBERT都有類似,因此就放到這里講了。這個思路非常優雅,真的非常優雅,它通過隨機使用小模型的一層替換大模型中若干層,來完成訓練。我來舉一個例子:假設大模型是input->tfc1->tfc2->tfc3->tfc4->tfc5->tfc6->output,然后再定義一個小模型input->sfc1->sfc2->sfc3->output。再訓練過程中還是要訓練大模型,只是在每一步中,會隨機的將(tfc1,tfc2),(tfc3,tfc4),(tfc5,tfc6)替換為sfc1,sfc2,sfc3,而且隨着訓練的進行,替換的概率不斷變大,因此最后就是在訓練一個小模型。
放一張圖便於大家理解

方式優雅,作者提供了源碼,強烈推薦大家用一用。

6 MiniLM

剛剛發布的一篇新論文, 也是關於BERT蒸餾的,我簡單總結下三個創新點:

  1. 先用TeacherModel蒸餾一個中等模型,再用中等模型蒸餾一個較小的StudentModel。只有在StudentModel很小的時候才會這么做。
  2. 只對最后一個隱層做蒸餾,作者認為這樣可以讓StudentModel有更大的自由空間,而且這樣對StudentModel架構的要求就變得寬松了
  3. 對於最后一個隱層主要是對attention權重做學習,我自己沒看的太明白,作者還沒釋放源碼,所以我就不講出來了,以免誤人子弟,知道朋友可以評論告知。

放一下圖以便大家理解:

7 知識蒸餾通用框架

7.1 KnowledgeDistillation庫

我自己實現了一個基於Pytorch簡單的知識蒸餾框架,有興趣的朋友可以試一試,使用這個框架可以實現TInyBERT、DistillBERT等,定制蒸餾靈活。目前剛出了第一版本,還有很多東西需要修改的,有興趣的朋友可以來一起修改這個庫。
Pypi:https://pypi.org/project/KnowledgeDistillation/
Github:https://github.com/DunZhang/KnowledgeDistillation

給大家提供一個范例代碼,使用12層bert蒸餾3層bert,使用TinyBERT的損失函數,完整可以直接運行,不需要外部數據:

import torch
import logging
import numpy as np
from transformers import BertModel, BertConfig
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from knowledge_distillation import KnowledgeDistiller, MultiLayerBasedDistillationLoss
from knowledge_distillation import MultiLayerBasedDistillationEvaluator

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Some global variables
train_batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 2e-5
num_epoch = 10
# Teacher Model
bert_config = BertConfig(num_hidden_layers=12, output_hidden_states=True, output_attentions=True)
teacher_model = BertModel(bert_config)
# Student Model
bert_config = BertConfig(num_hidden_layers=3, output_hidden_states=True, output_attentions=True)
student_model = BertModel(bert_config)

### Train data loader
input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 64)))
attention_mask = torch.LongTensor(np.ones((100000, 64)))
token_type_ids = torch.LongTensor(np.zeros((100000, 64)))
train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)


### Train data adaptor
### It is a function that turn batch_data (from train_dataloader) to the inputs of teacher_model and student_model
### You can define your own train_data_adaptor. Remember the input must be device and batch_data.
###  The output is either dict or tuple, but must consistent with you model's input
def train_data_adaptor(device, batch_data):
    batch_data = tuple(t.to(device) for t in batch_data)
    batch_data_dict = {"input_ids": batch_data[0],
                       "attention_mask": batch_data[1],
                       "token_type_ids": batch_data[2], }
    # In this case, the teacher and student use the same input
    return batch_data_dict, batch_data_dict


### The loss model is the key for this generation.
### We have already provided a general loss model for distilling multi bert layer
### In most cases, you can directly use this model.
#### First, we should define a distill_config which indicates how to compute ths loss between teacher and student.
#### distill_config is a list-object, each item indicates how to calculate loss.
#### It also defines which output of which layer to calculate loss.
#### type "ts_distill" means that we compute loss between teacher and student
#### type "hard_distill" means that we compute loss between student output and ground truth
#### loss_function can be mse, cross_entropy or cos. Args is extra parameters in this loss_function
#### loss_function(x,y,**args)
distill_config = [
    {"type": "ts_distill",
     "teacher_layer_name": "embedding_layer", "teacher_layer_output_name": "embedding",
     "student_layer_name": "embedding_layer", "student_layer_output_name": "embedding",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer4", "teacher_layer_output_name": "hidden_states",
     "student_layer_name": "bert_layer1", "student_layer_output_name": "hidden_states",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer4", "teacher_layer_output_name": "attention",
     "student_layer_name": "bert_layer1", "student_layer_output_name": "attention",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer8", "teacher_layer_output_name": "hidden_states",
     "student_layer_name": "bert_layer2", "student_layer_output_name": "hidden_states",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer8", "teacher_layer_output_name": "attention",
     "student_layer_name": "bert_layer2", "student_layer_output_name": "attention",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "hidden_states",
     "student_layer_name": "bert_layer3", "student_layer_output_name": "hidden_states",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
    {"type": "ts_distill",
     "teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "attention",
     "student_layer_name": "bert_layer3", "student_layer_output_name": "attention",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
]

### teacher_output_adaptor and student_output_adaptor
### In most cases, model's output is tuple-object, However, in our package, we need the output is dict-object,
### like: { "layer_name":{"output_name":value} .... }
### Hence, the output adaptor is to turn your model's output to dict-object output
### In my case, teacher and student can use one adaptor
def output_adaptor(model_output):
    last_hidden_state, pooler_output, hidden_states, attentions = model_output
    output = {"embedding_layer": {"embedding": hidden_states[0]}}
    for idx in range(len(attentions)):
        output["bert_layer" + str(idx + 1)] = {"hidden_states": hidden_states[idx + 1],
                                               "attention": attentions[idx]}
    return output


# loss_model
loss_model = MultiLayerBasedDistillationLoss(distill_config=distill_config,
                                             teacher_output_adaptor=output_adaptor,
                                             student_output_adaptor=output_adaptor)
# optimizer
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=learning_rate)
# evaluator
evaluator = MultiLayerBasedDistillationEvaluator(save_dir=None, save_step=None, print_loss_step=20)
# Get a KnowledgeDistiller
distiller = KnowledgeDistiller(teacher_model=teacher_model, student_model=student_model,
                               train_dataloader=train_dataloader, dev_dataloader=None,
                               train_data_adaptor=train_data_adaptor, dev_data_adaptor=None,
                               device=device, loss_model=loss_model, optimizer=optimizer,
                               evaluator=evaluator, num_epoch=num_epoch)
# start distillate
distiller.distillate()
7.2 TextBrewer庫

介紹完了自己的庫,再介紹一個知識蒸餾庫,這個庫是由哈工大搞的,比我的好多了,哈哈哈哈哈哈哈,我建議大家star我的庫,然后使用哈工大的庫。
Github:
https://github.com/airaria/TextBrewer

在這里同樣的也提供一個完整可運行的代碼,且不需要任何外部數據:

import torch
import numpy as np
import pickle
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertConfig, BertModel
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
## 定義模型
bert_config = BertConfig(num_hidden_layers=12, output_hidden_states=True, output_attentions=True)
teacher_model = BertModel(bert_config).to(device)
bert_config = BertConfig(num_hidden_layers=3, output_hidden_states=True, output_attentions=True)
student_model = BertModel(bert_config).to(device)

# optimizer
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=2e-5)

### data
input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 64)))
attention_mask = torch.LongTensor(np.ones((100000, 64)))
token_type_ids = torch.LongTensor(np.zeros((100000, 64)))
train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=16)


# Define an adaptor for translating the model inputs and outputs
# 整合成蒸餾器需要的數據格式
# key需要是固定的???

def bert_adaptor(batch, model_outputs):
    last_hidden_state, pooler_output, hidden_states, attentions = model_outputs
    hidden_states = list(hidden_states)
    hidden_states.append(pooler_output)
    output = {"inputs_mask": batch[1],
              "attention": attentions,
              "hidden": hidden_states}
    return output


# Training configuration
train_config = TrainingConfig(gradient_accumulation_steps=1,
                              ckpt_frequency=10,
                              ckpt_epoch_frequency=1,
                              log_dir='logs',
                              output_dir='saved_models',
                              device='cuda')
# Distillation configuration
# Matching different layers of the student and the teacher
# 重要,如何蒸餾的定義
# 不支持自定義損失函數
# 不支持CLS LOSS,但是可以強行寫在hidden loss里面
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},  # embedding loss
        {'layer_T': 4, 'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},  # hidden loss
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 12, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 3, 'layer_S': 0, 'feature': 'attention', 'loss': 'attention_mse', 'weight': 1},  # attention loss
        {'layer_T': 7, 'layer_S': 1, 'feature': 'attention', 'loss': 'attention_mse', 'weight': 1},
        {'layer_T': 11, 'layer_S': 2, 'feature': 'attention', 'loss': 'attention_mse', 'weight': 1},
        {'layer_T': 12, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},  # 其實是CLS loss
    ]

)

# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=bert_adaptor, adaptor_S=bert_adaptor)

# Start!
# callbacker 可以在dev上進行評估
# 注意存的是state_dict
with distiller:
    distiller.train(optimizer=optimizer, scheduler=None, dataloader=train_dataloader, num_epochs=10, callback=None)

8 其它加速BERT的方法

還有很多其他加速BERT的方法,我就不細說了,有興趣的可以研究下:

  1. 提升硬件,目前看性價比較高的是RTX2070Super和RTXTitan
  2. 提升深度學習框架版本必然能提升訓練和推理速度。比如高版本的TensorFlow會支持mkldnn,AVX指令集。
  3. ONNXRuntime(這個主要用在推理中)
  4. BERT的量化
  5. StructedDropout了解一下,但是這個最好用在預訓練上,那不然效果不好,ICLR2020的最新論文:reducing transformer depth on demand with structured dropout

文章可以隨意轉載,但請務必注明出處:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM