tf.Variable
__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None
)
功能說明:
維護圖在執行過程中的狀態信息,例如神經網絡權重值的變化。
參數列表:
| 參數名 | 類型 | 說明 |
|---|---|---|
| initial_value | 張量 | Variable 類的初始值,這個變量必須指定 shape 信息,否則后面 validate_shape 需設為 False |
| trainable | Boolean | 是否把變量添加到 collection GraphKeys.TRAINABLE_VARIABLES 中(collection 是一種全局存儲,不受變量名生存空間影響,一處保存,到處可取) |
| collections | Graph collections | 全局存儲,默認是 GraphKeys.GLOBAL_VARIABLES |
| validate_shape | Boolean | 是否允許被未知維度的 initial_value 初始化 |
| caching_device | string | 指明哪個 device 用來緩存變量 |
| name | string | 變量名 |
| dtype | dtype | 如果被設置,初始化的值就會按照這個類型初始化 |
| expected_shape | TensorShape | 要是設置了,那么初始的值會是這種維度 |
示例代碼:
import tensorflow as tf initial= tf.truncated_normal(shape=[10,10],mean=0,stddev=1) W=tf.Variable(initial) list=[[1.,1.],[2.,2.]] X=tf.Variable(list,dtype=tf.float32) ini_op=tf.global_variables_initializer() with tf.Session() as sess: sess.run(ini_op) print(sess.run(W[:2,:2])) op=W[:2,:2].assign(22.*tf.ones((2,2))) print(sess.run(op)) print (W.eval()) #Usage with the default session print ("#####################(6)#############") print (W.dtype) print (sess.run(W.initial_value)) print (sess.run(W.op)) print (W.shape) print ("###################(7)###############") print (sess.run(X))
