1. 模型參數的保存:
import tensorflow as tf
w=tf.Variable(0.0,name='graph_w')
ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='graph_ww')
# double=tf.multiply(2.0,w)
saver=tf.train.Saver({'weights_w':w,'weights_ww':ww}) # 此處模型文件關鍵字可以自己命名,如weights_w與weights_ww
# 關鍵字所對應的值名字為變量w與ww,而不是graph_w與graph_ww,否則會報錯。{'weights_w':w,'weights_ww':ww}為模型文件
# 需要保存的變量,用字典形式書寫出來,若無此字典,默認保存全部。
sess=tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(4):
d=sess.run(tf.assign_add(w,2)) # 這一步對w進行計算,得到最后值為8,最終將其保存saver種
# 其中 w 必須為變量名為w,不能是graph中的graph_w,否則會報錯
print(d)
print('w=',sess.run(w))
print('ww=',sess.run(ww))
saver.save(sess,'test.ckpt')
2. 模型參數的恢復:
import tensorflow as tf
restore_w=tf.Variable(0.0,name='weights_w')
restore_ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='weights_ww') # 盡管有初始值,但未調用tf.global_variables_initializer()此函數,則不會將其初始值賦值給該變量
# restore_w與restore_ww分別對應保存變量的w與ww的恢復,若想恢復則graph必須是對應保存變量名對應字典的關鍵字,
# 否則將會報錯。即恢復對應變量參數的變量名字可以自己重新命名,但graph中的名字必須是字典關鍵字。
double=tf.multiply(2.0,restore_w)
saver=tf.train.Saver()
sess=tf.Session()
saver.restore(sess,'test.ckpt')
f=sess.run(double)
print(f)
print('restore_ww=',sess.run(restore_ww))
總結:
① 變量初始化有2種方法,若不調用tf.global_variables_initializer()或tf.variables_initializer()函數就不會將變量restore_ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='weights_ww')
初始化。相反,使用saver.restore()將其變量初始化了。同時,也說明變量初始化才不會報錯。
② 模型參數以字典形式保存,其key可自己命名,其value必須為變量(非graph的name)。
③ 模型參數恢復對應變量可以自己命名,但對應變量中graph的name必須是保存模型參數對應變量的關鍵字(key)。