softmax_cross_entropy_with_logits


softmax_cross_entropy_with_logits

覺得有用的話,歡迎一起討論相互學習~


我的微博我的github我的B站

函數定義

def softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
                                      labels=None, logits=None,
                                      dim=-1, name=None)

解釋

  • 這個函數的作用是計算 logits 經 softmax 函數激活之后的交叉熵。
  • 對於每個獨立的分類任務,這個函數是去度量概率誤差。比如,在 CIFAR-10 數據集上面,每張圖片只有唯一一個分類標簽:一張圖可能是一只狗或者一輛卡車,但絕對不可能兩者都在一張圖中。(這也是和 tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)這個API的區別)

說明

  1. 輸入API的數據 logits 不能進行縮放,因為在這個API的執行中會進行 softmax 計算,如果 logits 進行了縮放,那么會影響計算正確率。
  2. 不要調用這個API去計算 softmax 的值,因為這個API最終輸出的結果並不是經過 softmax 函數的值。
  3. logits 和 labels 必須有相同的數據維度 [batch_size, num_classes],和相同的數據類型 float32 或者 float64 。
  4. 它適用於每個類別相互獨立且排斥的情況,一幅圖只能屬於一類,而不能同時包含一條狗和一只大象.

示例代碼

import tensorflow as tf

input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
output = tf.nn.softmax_cross_entropy_with_logits(logits=input_data, labels=[[0, 0, 1], [1, 0, 0]])
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(output))
    # [1.36573195]

參數

輸入參數

_sentinel: 這個參數一般情況不使用,直接設置為None就好

logits: 一個沒有縮放的對數張量。labels和logits具有相同的數據類型(type)和尺寸(shape)

labels: 每一行 labels[i] 必須是一個有效的概率分布值。

name: 為這個操作取個名字。

輸出參數

一個 Tensor ,數據維度是一維的,長度是 batch_size,數據類型都和 logits 相同。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM