tensorflow基礎【4】-計算圖 graph


tensorflow,tensor就是數據,flow就是流,tensorflow就是數據流

tensorflow 是一個用計算圖的形式來表示計算的編程系統,所有的數據和計算都會被轉化成計算圖上的一個節點,節點之間的邊就是數據流(數據流動的軌跡)。

 

計算圖的使用

1. 建立節點

2. 執行計算

 

計算圖有兩種形式

默認的計算圖

tf 維護一個默認的計算圖,

get_default_graph:獲取默認計算圖

graph:獲取節點所屬計算圖

import tensorflow as tf

a = tf.constant([1., 2.], name = 'a')
b = tf.constant([2., 3.], name = 'b')
result = a + b

print(a.graph is tf.get_default_graph())            # True

數據本身就是節點,該節點的 graph 就是默認計算圖

 

自定義計算圖

tf.Graph 可以生成新的計算圖,不同計算圖之間的數據和計算不能共享

## g1
g1 = tf.Graph()
with g1.as_default():
    # 在計算圖 g1 中定義變量 “v” ,並設置初始值為 0。
    v = tf.get_variable("v", [1], initializer = tf.zeros_initializer()) # 設置初始值為0,shape 為 1

## g2
g2 = tf.Graph()
with g2.as_default():
    # 在計算圖 g2 中定義變量 “v” ,並設置初始值為 1。
    v = tf.get_variable("v", [1], initializer = tf.ones_initializer()) # 設置初始值為1


# 在計算圖 g1 中讀取變量“v” 的取值
with tf.Session(graph = g1) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v")))               # [0.]

# 在計算圖 g2 中讀取變量“v” 的取值
with tf.Session(graph = g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v")))               # [1.]

## g3
g3 = tf.Graph()
with g3.as_default():
    # 在計算圖 g2 中定義變量 “v” ,並設置初始值為 1。
    v = tf.get_variable("v2", [1], initializer = tf.ones_initializer()) # 設置初始值為1


# 在計算圖 g1 中讀取變量“v” 的取值
with tf.Session(graph = g3) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v2")))               # [1.]
        print(sess.run(tf.get_variable("v")))                # 報錯 Variable v does not exist

可以看到 g3 無法調用 g2 中的變量v

計算圖可以用來隔離張量和計算

 

計算圖的操作

保存

g1 = tf.Graph()
with g1.as_default():
    # 需要加上名稱,在讀取pb文件的時候,是通過name和下標來取得對應的tensor的
    c1 = tf.constant(4.0, name='c1')

with tf.Session(graph=g1) as sess1:
    print(sess1.run(c1))                        # 4.0


# g1的圖定義,包含pb的path, pb文件名,是否是文本默認False
tf.train.write_graph(g1.as_graph_def(),'.','graph.pb',False)

讀取

import tensorflow as tf#load graph
with tf.gfile.FastGFile("./graph.pb",'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

sess = tf.Session()
c1_tensor = sess.graph.get_tensor_by_name("c1:0")
c1 = sess.run(c1_tensor)
print(c1)                       # 4.0

穿插調用

g1 = tf.Graph()
with g1.as_default():
    # 聲明的變量有名稱是一個好的習慣,方便以后使用
    c1 = tf.constant(4.0, name="c1")

g2 = tf.Graph()
with g2.as_default():
    c2 = tf.constant(20.0, name="c2")

with tf.Session(graph=g2) as sess1:
    # 通過名稱和下標來得到相應的值
    c1_list = tf.import_graph_def(g1.as_graph_def(), return_elements = ["c1:0"], name = '')
    print(sess1.run(c1_list[0]+c2))             # 24.0

 

指定計算圖的運行設備

g = tf.Graph()
# 指定計算運行的設備
with g.device('/gpu:0'):
    result = a + b

 

計算圖資源管理

在一個計算圖中,可以通過集合來管理不同的資源。

比如通過 tf.add_to_collection 將資源加入一個或多個集合中,然后通過 tf.get_collection 獲取一個集合里的所有資源

 

 

參考資料:

https://www.cnblogs.com/q735613050/p/7632792.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM