10 tensorflow在循環體中用tf.print輸出節點內容


代碼

i=tf.constant(0,dtype=tf.int32)
batch_len=tf.constant(10,dtype=tf.int32)
loop_cond = lambda a,b: tf.less(a,batch_len)
#yy=tf.Print(batch_len,[batch_len],"batch_len:")
yy=tf.constant(0)
loop_vars=[i,yy]
def _recurrence(i,yy):
    c=tf.constant(2,dtype=tf.int32)
    x=tf.multiply(i,c)
    print_info=tf.Print(x,[x],"x:")
    yy=yy+print_info
    i=tf.add(i,1)
    return i,yy
i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批處理
sess = tf.Session()
sess.run(yy)

輸出信息

為什么會這樣,因為執行sess.run(yy)的時候,會有數據流過循環體中的所有tf.Print節點,此時就會執行tf.Print中指定的輸出。最關鍵的操作就是yy=yy+print_info

存在的問題(與Spyder有關)

在spyder中使用調試模式的時候,無法輸出上面的信息。

上面的代碼是使用‘python 測試程序__在循環中使用tf.print.py’的方式在命令行執行才會輸出。

如何不斷的輸出tf.Print信息

除了上述使用yy=yy+print_info。

如果print_info是這樣的,比如:

print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"constructionErrorMatrix:")#專門為了調試用,輸出相關信息。
tfPrint=tfPrint+tf.to_int32(print_info[0])#一種不斷輸出tf.Print的方式,注意tf.Print的返回值。
constructionErrorMatrix是一個(?,)類型的float64 Tensor。我們可以用上述代碼,繼續進行tf.Print的輸出。

此外,tf.Print中的第二個參數[]中放入的內容,也必須是能夠轉為Tensor。否則會提示

TypeError: Tensors in list passed to 'data' of 'Print' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid.

比如,一個Tensor的shape中如果有“?,就不能轉換為Tensor。對於這種不能Tensor,我們不能用get_shape()[i].value去獲取?的維度,但是我們可以用tf.shape獲取有數據流入以后的動態維度。就是?最終確定的維度。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM