為什么要寫 tf.Graph().as_default()


首先,去tensorflow官網API上查詢 tf.Graph() 會看到如下圖所示的內容:

總體含義是說:

tf.Graph() 表示實例化了一個類,一個用於 tensorflow 計算和表示用的數據流圖,通俗來講就是:在代碼中添加的操作(畫中的結點)和數據(畫中的線條)都是畫在紙上的“畫”,而圖就是呈現這些畫的紙,你可以利用很多線程生成很多張圖,但是默認圖就只有一張。

tf.Graph().as_default() 表示將這個類實例,也就是新生成的圖作為整個 tensorflow 運行環境的默認圖,如果只有一個主線程不寫也沒有關系,tensorflow 里面已經存好了一張默認圖,可以使用tf.get_default_graph() 來調用(顯示這張默認紙),當你有多個線程就可以創造多個tf.Graph(),就是你可以有一個畫圖本,有很多張圖紙,這時候就會有一個默認圖的概念了。

具體的示例代碼如下,和圖中的一樣:

 1 import tensorflow as tf
 2 c=tf.constant(4.0)
 3 assert c.graph is tf.get_default_graph() #看看主程序中新建的一個變量是不是在默認圖里
 4 g=tf.Graph()
 5 with g.as_default():
 6     c=tf.constant(30.0)
 7     assert c.graph is g
 8 '''
 9 最終結果是沒有報錯
10 '''

結語:以上內容純屬自己理解,如有不當之處還請指正;歡迎轉載,標明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM