tf.Variable
__init__( initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None )
功能說明:
維護圖在執行過程中的狀態信息,例如神經網絡權重值的變化。
參數列表:
參數名 | 類型 | 說明 |
---|---|---|
initial_value | 張量 | Variable 類的初始值,這個變量必須指定 shape 信息,否則后面 validate_shape 需設為 False |
trainable | Boolean | 是否把變量添加到 collection GraphKeys.TRAINABLE_VARIABLES 中(collection 是一種全局存儲,不受變量名生存空間影響,一處保存,到處可取) |
collections | Graph collections | 全局存儲,默認是 GraphKeys.GLOBAL_VARIABLES |
validate_shape | Boolean | 是否允許被未知維度的 initial_value 初始化 |
caching_device | string | 指明哪個 device 用來緩存變量 |
name | string | 變量名 |
dtype | dtype | 如果被設置,初始化的值就會按照這個類型初始化 |
expected_shape | TensorShape | 要是設置了,那么初始的值會是這種維度 |
示例代碼:
import tensorflow as tf initial= tf.truncated_normal(shape=[10,10],mean=0,stddev=1) W=tf.Variable(initial) list=[[1.,1.],[2.,2.]] X=tf.Variable(list,dtype=tf.float32) ini_op=tf.global_variables_initializer() with tf.Session() as sess: sess.run(ini_op) print(sess.run(W[:2,:2])) op=W[:2,:2].assign(22.*tf.ones((2,2))) print(sess.run(op)) print (W.eval()) #Usage with the default session print ("#####################(6)#############") print (W.dtype) print (sess.run(W.initial_value)) print (sess.run(W.op)) print (W.shape) print ("###################(7)###############") print (sess.run(X))