關於tf.cond函數中“正確”與“錯誤”函數中的普通python語句始終執行的問題


 1 import tensorflow as tf
 2 import numpy as np
 3 x = tf.constant(2)
 4 y = tf.constant(3)
 5 global mask0
 6 mask0 = np.ones(shape=[1, 6, 1], dtype='float32')
 7 
 8 
 9 def f1():
10     mask0[0, 0, 0] = 0.
11     print(1)
12     return tf.multiply(x, 17)
13 
14 
15 def f2():
16     print(2)
17     return tf.add(y, 23)
18 
19 
20 test = tf.cond(pred=tf.equal(x, y), true_fn=f1, false_fn=f2)
21 
22 with tf.Session() as sess:
23     print(sess.run(test))
24     print(mask0)

執行結果:

 1 1
 2 2
 3 2018-08-01 10:25:46.443812: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
 4 26
 5 [[[0.]
 6   [1.]
 7   [1.]
 8   [1.]
 9   [1.]
10   [1.]]]
11 
12 進程已結束,退出代碼0

執行以上代碼發現,即使判斷為假,執行了f2,f1的值未返回,但f1中的語句依然執行了。最后發現,“正確”函數和“錯誤”函數都會在判斷語句中被調用始終按照先“正確”函數,后“錯誤”函數的順序執行一遍(經過將兩個函數定義位置調換,測試確實如此,始終先打印“正確”函數中的字符),其中的非張量數據流不受流程控制,tensorflow的流程控制只對tensor張量有效。最后由於本人需要完成的功能,采取了在兩個函數中只寫張量運算的代碼:

 1 import tensorflow as tf
 2 x = tf.constant(2)
 3 y = tf.constant(3)
 4 
 5 mask0 = tf.ones(shape=[1, 5, 1], dtype='float32')
 6 
 7 
 8 def f2():
 9     var0 = tf.one_hot(indices=[3], depth=5, axis=0)
10     print(2)
11     return var0
12 
13 
14 def f1():
15     var0 = tf.one_hot(indices=[1], depth=4, axis=0)
16     print(1)
17     return var0
18 
19 for h1 in range(1):
20     var0 = tf.cond(pred=tf.equal(x, y), true_fn=f1, false_fn=f2)
21     mask0 = mask0 - var0
22 
23 with tf.Session() as sess:
24     print(sess.run(mask0))

運行結果:

 1 1
 2 2
 3 2018-08-01 10:44:25.417305: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
 4 [[[1.]
 5   [1.]
 6   [1.]
 7   [0.]
 8   [1.]]]
 9 
10 進程已結束,退出代碼0

 

注意這里tf.cond的兩個函數返回數據維度不同也沒有報錯,經測試,當數據類型不同時會報錯。普通python語句和之前一樣按先“正確”后“錯誤”執行,tensorflow的數據流得到控制,只執行了“錯誤”語句。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM