weighted_cross_entropy_with_logits


weighted_cross_entropy_with_logits

覺得有用的話,歡迎一起討論相互學習~


我的微博我的github我的B站

weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):

此函數功能以及計算方式基本與tf_nn_sigmoid_cross_entropy_with_logits差不多,但是加上了權重的功能,是計算具有權重的sigmoid交叉熵函數

計算方法 :

\[pos_weight*targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits)) \]

官方文檔定義及推導過程:

通常的cross-entropy交叉熵函數定義如下:

\[targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))\]

對於加了權值pos_weight的交叉熵函數:

\[ targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))\]

現在我們使用 x = logits, z = targets, q = pos_weight的代數式

  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

我們把l = (1 + (q - 1) * z), 來確保穩定性並且比避免溢出,公式為:

\[(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0)) \]

logits and targets 必須要有相同的數據類型和shape.

參數:

_sentinel:本質上是不用的參數,不用填

targets:一個和logits具有相同的數據類型(type)和尺寸形狀(shape)的張量(tensor)

shape:[batch_size,num_classes],單樣本是[num_classes]

logits:一個數據類型(type)是float32或float64的張量

pos_weight:正樣本的一個系數

name:操作的名字,可填可不填

實例代碼

import numpy as np
import tensorflow as tf

input_data = tf.Variable(np.random.rand(3, 3), dtype=tf.float32)
# np.random.rand()傳入一個shape,返回一個在[0,1)區間符合均勻分布的array

output = tf.nn.weighted_cross_entropy_with_logits(logits=input_data,
                                                  targets=[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]],
                                                  pos_weight=2.0)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(output))
# [[ 1.04947078  0.89594436  0.92146152]
#  [ 0.70252579  1.00673866  1.08856964]
#  [ 1.07195592  1.18525708  1.04106498]]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM