tf中assign()函數可用於對變量進行更新包括變量的value和shape。
涉及以下函數:
- tf.assign(ref, value, validate_shape = None, use_locking = None, name=None)
- tf.assign_add(ref, value, use_locking = None, name=None)
- tf.assign_sub(ref, value, use_locking = None, name=None)
- tf.variable.assign(value, use_locking=False)
- tf.variable.assign_add(delta, use_locking=False)
- tf.variable.assign_sub(delta, use_locking=False)
這6個函數本質上是一樣的,都是用來對變量值進行更新,其中tf.assign還可以更新變量的shape。
解釋一下它們的意思:tf.assign是用value的值賦給ref,這種賦值會覆蓋掉原來的值,即更新而不會創建一個新的tensor。tf.assign_add相當於ref=ref+value來更新ref。tf.assign_sub相當於ref=ref-value來更新ref。tf.variable.assign相當於tf.assign(ref, value)。同理tf.variable.assign_add和tf.variable.assign_sub。
下面對tf.assign函數進行詳細說明。
tf.assign(ref, value, validate_shape = None, use_locking = None, name=None)
args:
- ref:一個可變的張量。應該來自變量節點,節點可能未初始化,參考下面的例子。
- value:張量。必須具有與 ref 相同的類型。是要分配給變量的值。
- validate_shape:一個可選的 bool。默認為 True。如果為 true, 則操作將驗證 "value" 的形狀是否與分配給的張量的形狀相匹配;如果為 false, "ref" 將對 "值" 的形狀進行引用。
- use_locking:一個可選的 bool。默認為 True。如果為 True, 則分配將受鎖保護;否則, 該行為是未定義的, 但可能會顯示較少的爭用。
- name:操作的名稱(可選)。
返回:
一個在賦值完成后將保留 "ref" 新值的張量。
現在舉三個例子,說明三個問題:
例子1:assign操作會初始化相關的節點,並不需要tf.global_variables_initializer()初始化,但是並非所有的節點都會被初始化。
#-*-coding:utf-8-*- import tensorflow as tf import numpy as np weights=tf.Variable(tf.random_normal([1,2],stddev=0.35),name="weights") biases=tf.Variable(tf.zeros([3]),name="biases") x_data = np.float32(np.random.rand(2, 3)) # 隨機輸入2行3列的數據 y = tf.matmul(weights, x_data) + biases update=tf.assign(weights,tf.random_normal([1,2],stddev=0.50))#正確 #update=weights.assign(tf.random_normal([1,2],stddev=0.50))#正確,和上句意義相同 #init=tf.global_variables_initializer() with tf.Session() as sess: #sess.run(init) for _ in range(2): sess.run(update) print(sess.run(weights))#正確,因為assign操作會初始化相關的節點 print(sess.run(y))#錯誤,因為使用了未初始化的biases變量
例子2:tf.assign()操作可以改變變量的shape,只需要令參數validate_shape=False,默認為True。
#-*-coding:utf-8-*- import tensorflow as tf x = tf.Variable(0) y = tf.assign(x, [5,2], validate_shape=False) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print (sess.run(x))#輸出0 print (sess.run(y))#輸出[5 2] print (sess.run(x))#輸出[5 2]
例子3:assign都會在圖中產生額外的操作,可用tf.Variable.load(value, session)實現從圖外賦值不產生額外的操作。
#-*-coding:utf-8-*- import tensorflow as tf x = tf.Variable(0) sess = tf.Session() sess.run(tf.global_variables_initializer()) print(sess.run(x)) # 輸出 0 x.load(5, sess) print(sess.run(x)) # 輸出 5
