解決tensorflow Saver.restore()無效的問題


解決tensorflow 的 Saver.restore()無法從本地讀取變量的問題

最近做tensorflow 手寫數字識別的時候遇到了一個問題,Saver的restore()方法無法從本地恢復變量,導致了每次都會重新訓練。

原來代碼

saver = tf.train.Saver(max_to_keep=5)
epoch = tf.Variable(0, name='epoch', trainable=False)

sess = tf.Session()
sess.run(tf.global_variables_initializer())


ckpt_dir = "./model/"

if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)


ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess, ckpt)
else:
    print('Train from scratch')

start = sess.run(epoch)

修改代碼

epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

ckpt_dir = "./model/"

if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)


ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess, ckpt)
else:
    print('Train from scratch')

start = sess.run(epoch)

其實主要改變的就是以下兩行的順序

epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM