assign()函數


tf中assign()函數可用於對變量進行更新包括變量的value和shape。

涉及以下函數:

  • tf.assign(ref, value, validate_shape = None, use_locking = None, name=None)
  • tf.assign_add(ref, value, use_locking = None, name=None)
  • tf.assign_sub(ref, value, use_locking = None, name=None)
  • tf.variable.assign(value, use_locking=False)
  • tf.variable.assign_add(delta, use_locking=False)
  • tf.variable.assign_sub(delta, use_locking=False)

這6個函數本質上是一樣的,都是用來對變量值進行更新,其中tf.assign還可以更新變量的shape。

解釋一下它們的意思:tf.assign是用value的值賦給ref,這種賦值會覆蓋掉原來的值,即更新而不會創建一個新的tensor。tf.assign_add相當於ref=ref+value來更新ref。tf.assign_sub相當於ref=ref-value來更新ref。tf.variable.assign相當於tf.assign(ref, value)。同理tf.variable.assign_add和tf.variable.assign_sub。

下面對tf.assign函數進行詳細說明。

tf.assign(ref, value, validate_shape = None, use_locking = None, name=None)

args:

  • ref:一個可變的張量。應該來自變量節點,節點可能未初始化,參考下面的例子。
  • value:張量。必須具有與 ref 相同的類型。是要分配給變量的值。
  • validate_shape:一個可選的 bool。默認為 True。如果為 true, 則操作將驗證 "value" 的形狀是否與分配給的張量的形狀相匹配;如果為 false, "ref" 將對 "值" 的形狀進行引用。
  • use_locking:一個可選的 bool。默認為 True。如果為 True, 則分配將受鎖保護;否則, 該行為是未定義的, 但可能會顯示較少的爭用。
  • name:操作的名稱(可選)。

返回:

一個在賦值完成后將保留 "ref" 新值的張量。

現在舉三個例子,說明三個問題:

例子1:assign操作會初始化相關的節點,並不需要tf.global_variables_initializer()初始化,但是並非所有的節點都會被初始化。

#-*-coding:utf-8-*-
import tensorflow as tf
import numpy as np

weights=tf.Variable(tf.random_normal([1,2],stddev=0.35),name="weights")
biases=tf.Variable(tf.zeros([3]),name="biases")
x_data = np.float32(np.random.rand(2, 3)) # 隨機輸入2行3列的數據

y = tf.matmul(weights, x_data) + biases

update=tf.assign(weights,tf.random_normal([1,2],stddev=0.50))#正確
#update=weights.assign(tf.random_normal([1,2],stddev=0.50))#正確,和上句意義相同

#init=tf.global_variables_initializer()

with tf.Session() as sess:
    #sess.run(init)
    for _ in range(2):
        sess.run(update)
        print(sess.run(weights))#正確,因為assign操作會初始化相關的節點
        print(sess.run(y))#錯誤,因為使用了未初始化的biases變量

 例子2:tf.assign()操作可以改變變量的shape,只需要令參數validate_shape=False,默認為True。

#-*-coding:utf-8-*-

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, [5,2], validate_shape=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print (sess.run(x))#輸出0
    print (sess.run(y))#輸出[5 2]
    print (sess.run(x))#輸出[5 2]

例子3:assign都會在圖中產生額外的操作,可用tf.Variable.load(value, session)實現從圖外賦值不產生額外的操作。

#-*-coding:utf-8-*-

import tensorflow as tf
x = tf.Variable(0)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(x))  # 輸出 0
x.load(5, sess)
print(sess.run(x))  # 輸出 5

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM