softmax_cross_entropy_with_logits
覺得有用的話,歡迎一起討論相互學習~
函數定義
def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
dim=-1, name=None)
解釋
- 這個函數的作用是計算 logits 經 softmax 函數激活之后的交叉熵。
- 對於每個獨立的分類任務,這個函數是去度量概率誤差。比如,在 CIFAR-10 數據集上面,每張圖片只有唯一一個分類標簽:一張圖可能是一只狗或者一輛卡車,但絕對不可能兩者都在一張圖中。(這也是和 tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)這個API的區別)
說明
- 輸入API的數據 logits 不能進行縮放,因為在這個API的執行中會進行 softmax 計算,如果 logits 進行了縮放,那么會影響計算正確率。
- 不要調用這個API去計算 softmax 的值,因為這個API最終輸出的結果並不是經過 softmax 函數的值。
- logits 和 labels 必須有相同的數據維度 [batch_size, num_classes],和相同的數據類型 float32 或者 float64 。
- 它適用於每個類別相互獨立且排斥的情況,一幅圖只能屬於一類,而不能同時包含一條狗和一只大象.
示例代碼
import tensorflow as tf
input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
output = tf.nn.softmax_cross_entropy_with_logits(logits=input_data, labels=[[0, 0, 1], [1, 0, 0]])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))
# [1.36573195]
參數
輸入參數
_sentinel: 這個參數一般情況不使用,直接設置為None就好
logits: 一個沒有縮放的對數張量。labels和logits具有相同的數據類型(type)和尺寸(shape)
labels: 每一行 labels[i] 必須是一個有效的概率分布值。
name: 為這個操作取個名字。
輸出參數
一個 Tensor ,數據維度是一維的,長度是 batch_size,數據類型都和 logits 相同。




