關於Tensorflow 加載和使用多個模型的方式


  在Tensorflow中,所有操作對象都包裝到相應的Session中的,所以想要使用不同的模型就需要將這些模型加載到不同的Session中並在使用的時候申明是哪個Session,從而避免由於Session和想使用的模型不匹配導致的錯誤。而使用多個graph,就需要為每個graph使用不同的Session,但是每個graph也可以在多個Session中使用,這個時候就需要在每個Session使用的時候明確申明使用的graph。

g1 = tf.Graph() # 加載到Session 1的graph
g2 = tf.Graph() # 加載到Session 2的graph

sess1 = tf.Session(graph=g1) # Session1
sess2 = tf.Session(graph=g2) # Session2

# 加載第一個模型
with sess1.as_default(): 
    with g1.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
# 加載第二個模型
with sess2.as_default():  # 1
    with g2.as_default():  
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

...

# 使用的時候
with sess1.as_default():
    with sess1.graph.as_default():  # 2
        ...

with sess2.as_default():
    with sess2.graph.as_default():
        ...

# 關閉sess
sess1.close()
sess2.close()

注意事項:

1、在1處使用as_default使session在離開的時候並不關閉,在后面可以繼續使用知道手動關閉;

2、由於有多個graph,所以sess.graph與tf.get_default_value的值是不相等的,因此在進入sess的時候必須sess.graph.as_default()明確申明sess.graph為當前默認graph,否則就會報錯。

3、不同框架的模型(tf, caffe, torch等)在加載的很有可能導致底層的cuDNN分配出現問題從而報錯,這種一般可以嘗試通過模型的加載順序來解決。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM