一、Distilling
訓練模型的優劣性取決於模型的泛化能力,在對預測數據進行預測時,會出現較好的預測結果;
通常情況下,復雜度高的網絡結構會具有較好的泛化能力,但是資源消耗較大,且存在信息冗余。而所謂的Distilling就是將復雜網絡中的有用信息提取出來遷移到一個更小的網絡上,這樣學習來的小網絡可以具備和大的復雜網絡想接近的性能效果,並且也大大的節省了計算資源。這個復雜的網絡可以看成一個教師,而小的網絡則可以看成是一個學生;蒸餾最終的目的是使得學生網絡可以具備老師網絡的性能,且降低模型復雜度,減少資源消耗。
二、pytorch-Distilling
import torch import torch.nn as nn import numpy as np from torch.nn import CrossEntropyLoss from torch.utils.data import TensorDataset,DataLoader,SequentialSampler class model(nn.Module): def __init__(self,input_dim,hidden_dim,output_dim): super(model,self).__init__() self.layer1 = nn.LSTM(input_dim,hidden_dim,output_dim,batch_first = True) self.layer2 = nn.Linear(hidden_dim,output_dim) def forward(self,inputs): layer1_output,layer1_hidden = self.layer1(inputs) layer2_output = self.layer2(layer1_output) layer2_output = layer2_output[:,-1,:]#取出一個batch中每個句子最后一個單詞的輸出向量即該句子的語義向量!!!!!!!!! return layer2_output #建立小模型 model_student = model(input_dim = 2,hidden_dim = 8,output_dim = 4) #建立大模型(此處仍然使用LSTM代替,可以使用訓練好的BERT等復雜模型) model_teacher = model(input_dim = 2,hidden_dim = 16,output_dim = 4) #設置輸入數據,此處只使用隨機生成的數據代替 inputs = torch.randn(4,6,2) true_label = torch.tensor([0,1,0,0]) #生成dataset dataset = TensorDataset(inputs,true_label) #生成dataloader sampler = SequentialSampler(inputs) dataloader = DataLoader(dataset = dataset,sampler = sampler,batch_size = 2) loss_fun = CrossEntropyLoss() criterion = nn.KLDivLoss()#KL散度 optimizer = torch.optim.SGD(model_student.parameters(),lr = 0.1,momentum = 0.9)#優化器,優化器中只傳入了學生模型的參數,因此此處只對學生模型進行參數更新,正好實現了教師模型參數不更新的目的 for step,batch in enumerate(dataloader): inputs = batch[0] labels = batch[1] #分別使用學生模型和教師模型對輸入數據進行計算 output_student = model_student(inputs) output_teacher = model_teacher(inputs) #計算學生模型和真實標簽之間的交叉熵損失函數值 loss_hard = loss_fun(output_student,labels) #計算學生模型預測結果和教師模型預測結果之間的KL散度 loss_soft = criterion(output_student,output_teacher) loss = 0.9*loss_soft + 0.1*loss_hard print(loss) optimizer.zero_grad() loss.backward() optimizer.step()
三、Reference
https://blog.csdn.net/libaominshouzhang/article/details/109777317
