解决tensorflow Saver.restore()无效的问题


解决tensorflow 的 Saver.restore()无法从本地读取变量的问题

最近做tensorflow 手写数字识别的时候遇到了一个问题,Saver的restore()方法无法从本地恢复变量,导致了每次都会重新训练。

原来代码

saver = tf.train.Saver(max_to_keep=5)
epoch = tf.Variable(0, name='epoch', trainable=False)

sess = tf.Session()
sess.run(tf.global_variables_initializer())


ckpt_dir = "./model/"

if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)


ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess, ckpt)
else:
    print('Train from scratch')

start = sess.run(epoch)

修改代码

epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

ckpt_dir = "./model/"

if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)


ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess, ckpt)
else:
    print('Train from scratch')

start = sess.run(epoch)

其实主要改变的就是以下两行的顺序

epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM