TensorFlow里面損失函數


損失算法的選取

損失函數的選取取決於輸入標簽數據的類型:

  • 如果輸入的是實數、無界的值,損失函數使用平方差;
  • 如果輸入標簽是位矢量(分類標志),使用交叉熵會更適合。

1.均值平方差

 

 

在TensorFlow沒有單獨的MSE函數,不過由於公式比較簡單,往往開發者都會自己組合,而且也可以寫出n種寫法,例如:

MSE=tf.reduce_mean(tf.pow(tf.sub(logits, outputs), 2.0))
MSE=tf.reduce_mean(tf.square(tf.sub(logits, outputs)))
MSE=tf.reduce_mean(tf.square(logits- outputs))

代碼中logits代表標簽值,outputs代表預測值

2.交叉熵

交叉熵(crossentropy)也是loss算法的一種,一般用在分類問題上,表達的意識為預測輸入樣本屬於某一類的概率 。其表達式如下,其中y代表真實值分類(0或1),a代表預測值。

 

 

在TensorFlow中常見的交叉熵函數有:

  • Sigmoid交叉熵;
  • softmax交叉熵;
  • Sparse交叉熵;
  • 加權Sigmoid交叉熵。

圖:在TensorFlow里常用的損失函數如表所示。

 

當然,也可以像MSE那樣使用自己組合的公式計算交叉熵,舉例,對於softmax后的結果logits我們可以對其使用公式-tf.reduce_sum(labels*tf.log(logits),1),就等同於softmax_cross_entropy_with_logits得到的結果。

 


import tensorflow as tf

labels = [[0, 0, 1], [0, 1, 0]]
logits = [[2, 0.5, 6], [0.1, 0, 3]]
logits_scaled = tf.nn.softmax(logits)
logits_scaled2 = tf.nn.softmax(logits_scaled)

result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
result2 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits_scaled)
result3 = -tf.reduce_sum(labels * tf.log(logits_scaled), 1)

with tf.Session() as sess:
print("scaled=", sess.run(logits_scaled))
print("scaled2=", sess.run(logits_scaled2))
# 經過第二次的softmax后,分布概率會有變化

print("rel1=", sess.run(result1), "\n") # 正確的方式
print("rel2=", sess.run(result2), "\n")
# 如果將softmax變換完的值放進去會,就相當於算第二次softmax的loss,所以會出錯
print("rel3=", sess.run(result3))

 

運行上面代碼,輸出結果如下:

scaled= [[ 0.01791432 0.00399722 0.97808844]
[ 0.04980332 0.04506391 0.90513283]]
scaled2= [[ 0.21747023 0.21446465 0.56806517]
[ 0.2300214 0.22893383 0.54104471]]
rel1= [ 0.02215516 3.09967351]
rel2= [ 0.56551915 1.47432232]
rel3= [ 0.02215518 3.09967351]

 

下面開始驗證下前面所說的實驗:

  • 比較scaled和scaled2可以看到:經過第二次的softmax后,分布概率會有變化,而scaled才是我們真實轉化的softmax值。

  • 比較rel1和rel2可以看到:傳入softmax_cross_entropy_with_logits的logits是不需要進行softmax的。如果將softmax后的值scaled傳入softmax_cross_entropy_with_logits就相當於進行了兩次的softmax轉換。



 





免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM