本文簡單介紹知識蒸餾教師模型與學生模型使用loss方法:
一 .loss求解方法
hard label:訓練的學生模型結果與真實標簽進行交叉熵loss,類似正常網絡訓練。
soft label:訓練的學生網絡與已經訓練好的教師網絡進行KL相對熵求解,可添加系數,如溫度,使其更soft。
知乎回答:loss是KL divergence,用來衡量兩個分布之間距離。而KL divergence在展開之后,第一項是原始預測分布的熵,由於是已知固定的,可以消去。第二項是 -q log p,叫做cross entropy,就是平時分類訓練使用的loss。與標簽label不同的是,這里的q是teacher model的預測輸出連續概率。而如果進一步假設q p都是基於softmax函數輸出的概率的話,求導之后形式就是 q - p。直觀理解就是讓student model的輸出盡量向teacher model的輸出概率靠近。
二.展示蒸餾網絡過程圖
三.展示代碼與結果

三.展示代碼與結果
蒸餾模型分類loss代碼如下:
import torch import torch.nn as nn import numpy as np loss_f = nn.KLDivLoss() # 生成網絡輸出 以及 目標輸出 model_student = torch.from_numpy(np.array([[0.1132, 0.5477, 0.3390]])).float() # 假設學生模型輸出 model_teacher = torch.from_numpy(np.array([[0.8541, 0.0511, 0.0947]])).float() #假設教師模型輸出 label=torch.tensor([0]) # 真實標簽 loss_KD = loss_f(model_student, model_teacher) L=nn.CrossEntropyLoss() loss_SL=L(model_student,label) lambda_ ,T=0.6,3 # 分別為設置權重參數,T為溫度系數 loss = (1 - lambda_) * loss_SL + lambda_ * T * T * loss_KD # hint和jeff dean論文 print('\nloss: ', loss)
結果圖顯示: