【tensorflow2.0】損失函數losses


一般來說,監督學習的目標函數由損失函數和正則化項組成。(Objective = Loss + Regularization)

對於keras模型,目標函數中的正則化項一般在各層中指定,例如使用Dense的 kernel_regularizer 和 bias_regularizer等參數指定權重使用l1或者l2正則化項,此外還可以用kernel_constraint 和 bias_constraint等參數約束權重的取值范圍,這也是一種正則化手段。

損失函數在模型編譯時候指定。對於回歸模型,通常使用的損失函數是平方損失函數 mean_squared_error。

對於二分類模型,通常使用的是二元交叉熵損失函數 binary_crossentropy。

對於多分類模型,如果label是類別序號編碼的,則使用類別交叉熵損失函數 categorical_crossentropy。如果label進行了one-hot編碼,則需要使用稀疏類別交叉熵損失函數 sparse_categorical_crossentropy。

如果有需要,也可以自定義損失函數,自定義損失函數需要接收兩個張量y_true,y_pred作為輸入參數,並輸出一個標量作為損失函數值。

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,regularizers,constraints

一,損失函數和正則化項

tf.keras.backend.clear_session()
 
model = models.Sequential()
model.add(layers.Dense(64, input_dim=64,
                kernel_regularizer=regularizers.l2(0.01), 
                activity_regularizer=regularizers.l1(0.01),
                kernel_constraint = constraints.MaxNorm(max_value=2, axis=0))) 
model.add(layers.Dense(10,
        kernel_regularizer=regularizers.l1_l2(0.01,0.01),activation = "sigmoid"))
model.compile(optimizer = "rmsprop",
        loss = "sparse_categorical_crossentropy",metrics = ["AUC"])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 64)                4160      
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 4,810
Trainable params: 4,810
Non-trainable params: 0
_________________________________________________________________

二,內置損失函數

內置的損失函數一般有類的實現和函數的實現兩種形式。

如:CategoricalCrossentropy 和 categorical_crossentropy 都是類別交叉熵損失函數,前者是類的實現形式,后者是函數的實現形式。

常用的一些內置損失函數說明如下。

  • mean_squared_error(平方差誤差損失,用於回歸,簡寫為 mse, 類實現形式為 MeanSquaredError 和 MSE)

  • mean_absolute_error (絕對值誤差損失,用於回歸,簡寫為 mae, 類實現形式為 MeanAbsoluteError 和 MAE)

  • mean_absolute_percentage_error (平均百分比誤差損失,用於回歸,簡寫為 mape, 類實現形式為 MeanAbsolutePercentageError 和 MAPE)

  • Huber(Huber損失,只有類實現形式,用於回歸,介於mse和mae之間,對異常值比較魯棒,相對mse有一定的優勢)

  • binary_crossentropy(二元交叉熵,用於二分類,類實現形式為 BinaryCrossentropy)

  • categorical_crossentropy(類別交叉熵,用於多分類,要求label為onehot編碼,類實現形式為 CategoricalCrossentropy)

  • sparse_categorical_crossentropy(稀疏類別交叉熵,用於多分類,要求label為序號編碼形式,類實現形式為 SparseCategoricalCrossentropy)

  • hinge(合頁損失函數,用於二分類,最著名的應用是作為支持向量機SVM的損失函數,類實現形式為 Hinge)

  • kld(相對熵損失,也叫KL散度,常用於最大期望算法EM的損失函數,兩個概率分布差異的一種信息度量。類實現形式為 KLDivergence 或 KLD)

  • cosine_similarity(余弦相似度,可用於多分類,類實現形式為 CosineSimilarity)

三,自定義損失函數

自定義損失函數接收兩個張量y_true,y_pred作為輸入參數,並輸出一個標量作為損失函數值。

也可以對tf.keras.losses.Loss進行子類化,重寫call方法實現損失的計算邏輯,從而得到損失函數的類的實現。

下面是一個Focal Loss的自定義實現示范。Focal Loss是一種對binary_crossentropy的改進損失函數形式。

在類別不平衡和存在難以訓練樣本的情形下相對於二元交叉熵能夠取得更好的效果。

詳見《如何評價Kaiming的Focal Loss for Dense Object Detection?》

https://www.zhihu.com/question/63581984

def focal_loss(gamma=2., alpha=0.25):
 
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        loss = -tf.sum(alpha * tf.pow(1. - pt_1, gamma) * tf.log(1e-07+pt_1)) \
           -tf.sum((1-alpha) * tf.pow( pt_0, gamma) * tf.log(1. - pt_0 + 1e-07))
        return loss
    return focal_loss_fixed
 
class FocalLoss(losses.Loss):
 
    def __init__(self,gamma=2.0,alpha=0.25):
        self.gamma = gamma
        self.alpha = alpha
 
    def call(self,y_true,y_pred):
 
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        loss = -tf.sum(self.alpha * tf.pow(1. - pt_1, self.gamma) * tf.log(1e-07+pt_1)) \
           -tf.sum((1-self.alpha) * tf.pow( pt_0, self.gamma) * tf.log(1. - pt_0 + 1e-07))
        return loss

 

參考:

開源電子書地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub 項目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM