知識蒸餾loss求解方法


本文簡單介紹知識蒸餾教師模型與學生模型使用loss方法:

一 .loss求解方法

hard label:訓練的學生模型結果與真實標簽進行交叉熵loss,類似正常網絡訓練。

soft label:訓練的學生網絡與已經訓練好的教師網絡進行KL相對熵求解,可添加系數,如溫度,使其更soft。

知乎回答:loss是KL divergence,用來衡量兩個分布之間距離。而KL divergence在展開之后,第一項是原始預測分布的熵,由於是已知固定的,可以消去。第二項是 -q log p,叫做cross entropy,就是平時分類訓練使用的loss。與標簽label不同的是,這里的q是teacher model的預測輸出連續概率。而如果進一步假設q p都是基於softmax函數輸出的概率的話,求導之后形式就是 q - p。直觀理解就是讓student model的輸出盡量向teacher model的輸出概率靠近。
 
二.展示蒸餾網絡過程圖

 

 


三.展示代碼與結果

蒸餾模型分類loss代碼如下:

import torch
import torch.nn as nn
import numpy as np

loss_f = nn.KLDivLoss()

# 生成網絡輸出 以及 目標輸出
model_student = torch.from_numpy(np.array([[0.1132, 0.5477, 0.3390]])).float() # 假設學生模型輸出

model_teacher = torch.from_numpy(np.array([[0.8541, 0.0511, 0.0947]])).float() #假設教師模型輸出
label=torch.tensor([0])  # 真實標簽
loss_KD = loss_f(model_student, model_teacher)
L=nn.CrossEntropyLoss()
loss_SL=L(model_student,label)
lambda_ ,T=0.6,3  # 分別為設置權重參數,T為溫度系數
loss = (1 - lambda_) * loss_SL + lambda_ * T * T * loss_KD  # hint和jeff dean論文

print('\nloss: ', loss)

 

 

結果圖顯示:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM