Pytorch 蒸餾(Distilling)


一、Distilling

訓練模型的優劣性取決於模型的泛化能力,在對預測數據進行預測時,會出現較好的預測結果;

通常情況下,復雜度高的網絡結構會具有較好的泛化能力,但是資源消耗較大,且存在信息冗余。而所謂的Distilling就是將復雜網絡中的有用信息提取出來遷移到一個更小的網絡上,這樣學習來的小網絡可以具備和大的復雜網絡想接近的性能效果,並且也大大的節省了計算資源。這個復雜的網絡可以看成一個教師,而小的網絡則可以看成是一個學生;蒸餾最終的目的是使得學生網絡可以具備老師網絡的性能,且降低模型復雜度,減少資源消耗。

二、pytorch-Distilling

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset,DataLoader,SequentialSampler

class model(nn.Module):
    def __init__(self,input_dim,hidden_dim,output_dim):
        super(model,self).__init__()
        self.layer1 = nn.LSTM(input_dim,hidden_dim,output_dim,batch_first = True)
        self.layer2 = nn.Linear(hidden_dim,output_dim)
    def forward(self,inputs):
        layer1_output,layer1_hidden = self.layer1(inputs)
        layer2_output = self.layer2(layer1_output)
        layer2_output = layer2_output[:,-1,:]#取出一個batch中每個句子最后一個單詞的輸出向量即該句子的語義向量!!!!!!!!!
        return layer2_output

#建立小模型
model_student = model(input_dim = 2,hidden_dim = 8,output_dim = 4)

#建立大模型(此處仍然使用LSTM代替,可以使用訓練好的BERT等復雜模型)
model_teacher = model(input_dim = 2,hidden_dim = 16,output_dim = 4)

#設置輸入數據,此處只使用隨機生成的數據代替
inputs = torch.randn(4,6,2)
true_label = torch.tensor([0,1,0,0])

#生成dataset
dataset = TensorDataset(inputs,true_label)

#生成dataloader
sampler = SequentialSampler(inputs)
dataloader = DataLoader(dataset = dataset,sampler = sampler,batch_size = 2)

loss_fun = CrossEntropyLoss()
criterion  = nn.KLDivLoss()#KL散度
optimizer = torch.optim.SGD(model_student.parameters(),lr = 0.1,momentum = 0.9)#優化器,優化器中只傳入了學生模型的參數,因此此處只對學生模型進行參數更新,正好實現了教師模型參數不更新的目的

for step,batch in enumerate(dataloader):
    inputs = batch[0]
    labels = batch[1]
    
    #分別使用學生模型和教師模型對輸入數據進行計算
    output_student = model_student(inputs)
    output_teacher = model_teacher(inputs)
    
    #計算學生模型和真實標簽之間的交叉熵損失函數值
    loss_hard = loss_fun(output_student,labels)
    
    #計算學生模型預測結果和教師模型預測結果之間的KL散度
    loss_soft = criterion(output_student,output_teacher)
    
    loss = 0.9*loss_soft + 0.1*loss_hard
    print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

三、Reference

https://blog.csdn.net/libaominshouzhang/article/details/109777317

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM