在Tensorflow中,所有操作對象都包裝到相應的Session中的,所以想要使用不同的模型就需要將這些模型加載到不同的Session中並在使用的時候申明是哪個Session,從而避免由於Session和想使用的模型不匹配導致的錯誤。而使用多個graph,就需要為每個graph使用不同的Session,但是每個graph也可以在多個Session中使用,這個時候就需要在每個Session使用的時候明確申明使用的graph。
g1 = tf.Graph() # 加載到Session 1的graph g2 = tf.Graph() # 加載到Session 2的graph sess1 = tf.Session(graph=g1) # Session1 sess2 = tf.Session(graph=g2) # Session2 # 加載第一個模型 with sess1.as_default(): with g1.as_default(): tf.global_variables_initializer().run() model_saver = tf.train.Saver(tf.global_variables()) model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”) model_saver.restore(sess, model_ckpt.model_checkpoint_path) # 加載第二個模型 with sess2.as_default(): # 1 with g2.as_default(): tf.global_variables_initializer().run() model_saver = tf.train.Saver(tf.global_variables()) model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”) model_saver.restore(sess, model_ckpt.model_checkpoint_path) ... # 使用的時候 with sess1.as_default(): with sess1.graph.as_default(): # 2 ... with sess2.as_default(): with sess2.graph.as_default(): ... # 關閉sess sess1.close() sess2.close()
注意事項:
1、在1處使用as_default使session在離開的時候並不關閉,在后面可以繼續使用知道手動關閉;
2、由於有多個graph,所以sess.graph與tf.get_default_value的值是不相等的,因此在進入sess的時候必須sess.graph.as_default()明確申明sess.graph為當前默認graph,否則就會報錯。
3、不同框架的模型(tf, caffe, torch等)在加載的很有可能導致底層的cuDNN分配出現問題從而報錯,這種一般可以嘗試通過模型的加載順序來解決。