tf.while_loop


while 循環

def while_loop(cond,          ### 一個函數,負責判斷循環是否進行
               body,          ### 一個函數,循環體,更新變量
               loop_vars,     ### 初始循環變量,可以是多個,這些變量是 cond、body 的輸入 和輸出
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):

返回 迭代后的 loop_vars

 

def cond(i, n):
    return i < n

def body(i, n):
    i = i + 1
    return i, n

i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
# i = 1                 # 也可以
# i = tf.constant(1)    # 也可以
n = tf.constant(10)
i, n = tf.while_loop(cond, body, [i, n])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    res = sess.run([i, n])
    print(res)      # [10, 10]

注意:cond 和 body 的輸入和輸出要相同,且等於 loop_vars,即使在函數中沒有用到全部的 loop_vars,也要做為輸入和輸出

 

 

參考資料:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM