tensorflow中一種融合多個模型的方法


1.使用場景

假設我們有訓練好的模型A,B,C,我們希望使用A,B,C中的部分或者全部變量,合成為一個模型D,用於初始化或其他目的,就需要融合多個模型的方法

 

2.如何實現

我們可以先聲明模型D,再創建多個Saver實例,分別從模型A,B,C的保存文件(checkpoint文件)中讀取所需的變量值,來達成這一目的,下面是示例代碼:

首先創建一個只包含w1,w2兩個變量的模型,初始化后保存:

 1 def train_model1():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(), trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a1 = tf.matmul(x, w1)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver1 = tf.train.Saver([w1,w2])
10     for i in range(0, 1):
11         w1_var,w2_var = sess.run([w1,w2], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w1_var
13         print w2_var
14         print '=' * 30
15     saver1.save(sess, 'save1-exp')

然后再創建一個只包含w2,w3兩個變量的模型,也是初始化后保存:

 1 def train_model2():
 2     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a2 = tf.matmul(x, w2 * w3)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver2 = tf.train.Saver([w2,w3])
10     for i in range(0, 1):
11         w2_var, w3_var = sess.run([w2, w3], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w2_var
13         print w3_var
14         print '=' * 30
15     saver2.save(sess, 'save2-exp')

最后我們創建一個包含w1,w2,w3變量的模型,從上面兩個保存的ckp文件中恢復:

 1 def restore_model():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 5     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 6     a1 = tf.matmul(x, w1)
 7     a2 = tf.matmul(x, w2 * w3)
 8     loss = tf.reduce_mean(tf.square(a1 - a2))
 9     sess = tf.InteractiveSession()
10     sess.run(tf.global_variables_initializer())
11     saver1 = tf.train.Saver([w1,w2])
12     saver1.restore(sess, 'save1-exp')
13     saver2 = tf.train.Saver([w2, w3])
14     saver2.restore(sess, 'save2-exp')
15     saver3 = tf.train.Saver(tf.trainable_variables())
16     input = np.random.rand(3200, 3)
17     w1_var, w2_var, w3_var = sess.run([w1, w2, w3], feed_dict={x: input[0:32]})
18     print w1_var
19     print w2_var
20     print w3_var
21     print '=' * 30
22     saver3.save(sess, 'save3-exp')

然后保存,即完成了我們的目標

 

3.注意事項

3.1 取的模型中有同名變量

假設同名變量為a,這種情況下,從不同模型中恢復的a是按照讀取順序覆蓋到a中的,如果希望只讀取特定ckpt保存的變量值,在創建讀取其他ckpt的saver時,不要把a加入到var_list中

3.2 模型D中有部分變量不在A,B,C中

這種情況,恢復時會報錯,需要指定var_list,只恢復當前cpkt中保存的變量


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM