正確理解TensorFlow中的logits


softmax是在一個n分類問題中,輸入一個n維的logits向量,輸出一個n維概率向量,其物理意義是logits代表的物體屬於各類的概率。即softmax的輸出是一個n維的one_hot_prediction。
softmax_cross_entropy_with_logits輸出的是一個batch_size維的向量,這個向量的每一維表示每一個sample的one_hot_label和one_hot_prediction 之間的交叉熵。
softmax_cross_entropy的輸出是前述batch_size維的向量的均值。

總結觀點:

logits與 softmax都屬於在輸出層的內容,

logits = tf.matmul(X, W) + bias

再對logits做歸一化處理,就用到了softmax:

Y_pred = tf.nn.softmax(logits,name='Y_pred')

——————————————————————

Unscaled log probabilities of shape [d_0, d_1, ..., d_{r-1}, num_classes] and dtype float32 or float64.

可以理解logits ——【batchsize,class_num】是未進入softmax的概率,一般是全連接層的輸出,softmax的輸入

注,一般將沒有加激活函數的稱為Logits,加了softmax后稱為Probabilities,經過softmax后,有把最大值放大的過程,相當於把強的變得更強,把弱的變得更弱。

 

我正想通過tensorflow API文檔在這里。在tensorflow文檔中,他們使用了一個叫做關鍵字logits。它是什么?API文檔中的很多方法都是這樣寫的

tf.nn.softmax(logits, name=None)

如果寫的是logits只有這些Tensors,為什么要保留一個不同的名字logits?

另一件事是有兩種方法我不能區分。他們是

tf.nn.softmax(logits, name=None)
tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

他們之間有什么不同?文檔對我不明確。我知道是什么tf.nn.softmax。但不是其他。一個例子會非常有用。

假設您有兩個張量,其中y_hat包含每個類的計算得分(例如,從y = W * x + b),並y_true包含一個熱點編碼的真實標簽。

y_hat  = ... # Predicted label, e.g. y = tf.matmul(X, W) + b
y_true = ... # True label, one-hot encoded

如果您將分數解釋為y_hat非標准化的日志概率,那么它們就是logits。

另外,以這種方式計算的總交叉熵損失:

y_hat_softmax = tf.nn.softmax(y_hat)
total_loss = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), [1]))

本質上等價於用函數計算的總交叉熵損失softmax_cross_entropy_with_logits():

total_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))

在神經網絡的輸出層中,您可能會計算一個數組,其中包含每個訓練實例的類分數,例如來自計算y_hat = W*x + b。作為一個例子,下面我創建了y_hat一個2×3數組,其中行對應於訓練實例,列對應於類。所以這里有2個訓練實例和3個類別。

import tensorflow as tf
import numpy as np
sess = tf.Session() # Create example y_hat. 
y_hat = tf.convert_to_tensor(np.array([[0.5, 1.5, 0.1],[2.2, 1.3, 1.7]]))
sess.run(y_hat) # array([[ 0.5, 1.5, 0.1], # [ 2.2, 1.3, 1.7]])

請注意,這些值沒有標准化(即每一行的和不等於1)。為了對它們進行歸一化,我們可以應用softmax函數,它將輸入解釋為非歸一化對數概率(又名logits)並輸出歸一化的線性概率。

y_hat_softmax = tf.nn.softmax(y_hat)
sess.run(y_hat_softmax)
# array([[ 0.227863  ,  0.61939586,  0.15274114],
#        [ 0.49674623,  0.20196195,  0.30129182]])

充分理解softmax輸出的含義非常重要。下面我列出了一張更清楚地表示上面輸出的表格。可以看出,例如,訓練實例1為“2類”的概率為0.619。每個訓練實例的類概率被歸一化,所以每行的總和為1.0。

                      Pr(Class 1)  Pr(Class 2)  Pr(Class 3)
                    ,--------------------------------------
Training instance 1 | 0.227863   | 0.61939586 | 0.15274114
Training instance 2 | 0.49674623 | 0.20196195 | 0.30129182

所以現在我們有每個訓練實例的類概率,我們可以在每個行的argmax()中生成最終的分類。從上面,我們可以生成訓練實例1屬於“2類”,訓練實例2屬於“1類”。

這些分類是否正確?我們需要根據訓練集中的真實標簽進行測量。您將需要一個熱點編碼y_true數組,其中行又是訓練實例,列是類。下面我創建了一個示例y_trueone-hot數組,其中訓練實例1的真實標簽為“Class 2”,訓練實例2的真實標簽為“Class 3”。

y_true = tf.convert_to_tensor(np.array([[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]]))
sess.run(y_true)
# array([[ 0.,  1.,  0.],
#        [ 0.,  0.,  1.]])

概率分布是否y_hat_softmax接近概率分布y_true?我們可以使用交叉熵損失來衡量錯誤。

Formula for cross-entropy loss

 

我們可以逐行計算交叉熵損失並查看結果。下面我們可以看到,訓練實例1損失了0.479,而訓練實例2損失了1.200。這個結果是有道理的,因為在我們上面的例子中y_hat_softmax,訓練實例1的最高概率是“類2”,它與訓練實例1匹配y_true; 然而,訓練實例2的預測顯示“1類”的最高概率,其與真實類“3類”不匹配。

loss_per_instance_1 = -tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1])
sess.run(loss_per_instance_1)
# array([ 0.4790107 ,  1.19967598])

我們真正想要的是所有培訓實例的全部損失。所以我們可以計算:

total_loss_1 = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1]))
sess.run(total_loss_1)
# 0.83934333897877944

我們可以用tf.nn.softmax_cross_entropy_with_logits()函數來計算總的交叉熵損失,如下所示。

loss_per_instance_2 = tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true)
sess.run(loss_per_instance_2)
# array([ 0.4790107 ,  1.19967598])

total_loss_2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))
sess.run(total_loss_2)
# 0.83934333897877922

請注意,total_loss_1並total_loss_2產生基本相同的結果,在最后一位數字中有一些小的差異。但是,你可以使用第二種方法:它只需要少一行代碼,並累積更少的數字錯誤,因為softmax是在你內部完成的softmax_cross_entropy_with_logits()。


參考:https://www.jianshu.com/p/6c9b0cc6978b


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM