关于tf.cond函数中“正确”与“错误”函数中的普通python语句始终执行的问题


 1 import tensorflow as tf
 2 import numpy as np
 3 x = tf.constant(2)
 4 y = tf.constant(3)
 5 global mask0
 6 mask0 = np.ones(shape=[1, 6, 1], dtype='float32')
 7 
 8 
 9 def f1():
10     mask0[0, 0, 0] = 0.
11     print(1)
12     return tf.multiply(x, 17)
13 
14 
15 def f2():
16     print(2)
17     return tf.add(y, 23)
18 
19 
20 test = tf.cond(pred=tf.equal(x, y), true_fn=f1, false_fn=f2)
21 
22 with tf.Session() as sess:
23     print(sess.run(test))
24     print(mask0)

执行结果:

 1 1
 2 2
 3 2018-08-01 10:25:46.443812: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
 4 26
 5 [[[0.]
 6   [1.]
 7   [1.]
 8   [1.]
 9   [1.]
10   [1.]]]
11 
12 进程已结束,退出代码0

执行以上代码发现,即使判断为假,执行了f2,f1的值未返回,但f1中的语句依然执行了。最后发现,“正确”函数和“错误”函数都会在判断语句中被调用始终按照先“正确”函数,后“错误”函数的顺序执行一遍(经过将两个函数定义位置调换,测试确实如此,始终先打印“正确”函数中的字符),其中的非张量数据流不受流程控制,tensorflow的流程控制只对tensor张量有效。最后由于本人需要完成的功能,采取了在两个函数中只写张量运算的代码:

 1 import tensorflow as tf
 2 x = tf.constant(2)
 3 y = tf.constant(3)
 4 
 5 mask0 = tf.ones(shape=[1, 5, 1], dtype='float32')
 6 
 7 
 8 def f2():
 9     var0 = tf.one_hot(indices=[3], depth=5, axis=0)
10     print(2)
11     return var0
12 
13 
14 def f1():
15     var0 = tf.one_hot(indices=[1], depth=4, axis=0)
16     print(1)
17     return var0
18 
19 for h1 in range(1):
20     var0 = tf.cond(pred=tf.equal(x, y), true_fn=f1, false_fn=f2)
21     mask0 = mask0 - var0
22 
23 with tf.Session() as sess:
24     print(sess.run(mask0))

运行结果:

 1 1
 2 2
 3 2018-08-01 10:44:25.417305: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
 4 [[[1.]
 5   [1.]
 6   [1.]
 7   [0.]
 8   [1.]]]
 9 
10 进程已结束,退出代码0

 

注意这里tf.cond的两个函数返回数据维度不同也没有报错,经测试,当数据类型不同时会报错。普通python语句和之前一样按先“正确”后“错误”执行,tensorflow的数据流得到控制,只执行了“错误”语句。

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM