解決tensorflow 的 Saver.restore()無法從本地讀取變量的問題
最近做tensorflow 手寫數字識別的時候遇到了一個問題,Saver的restore()方法無法從本地恢復變量,導致了每次都會重新訓練。
原來代碼
saver = tf.train.Saver(max_to_keep=5)
epoch = tf.Variable(0, name='epoch', trainable=False)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
ckpt_dir = "./model/"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
ckpt = tf.train.latest_checkpoint(ckpt_dir)
if ckpt != None:
saver.restore(sess, ckpt)
else:
print('Train from scratch')
start = sess.run(epoch)
修改代碼
epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
ckpt_dir = "./model/"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
ckpt = tf.train.latest_checkpoint(ckpt_dir)
if ckpt != None:
saver.restore(sess, ckpt)
else:
print('Train from scratch')
start = sess.run(epoch)
其實主要改變的就是以下兩行的順序
epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)