Sklearn,TensorFlow,keras模型保存與讀取


一、sklearn模型保存與讀取 
1、保存

1 from sklearn.externals import joblib
2 from sklearn import svm
3 X = [[0, 0], [1, 1]]
4 y = [0, 1]
5 clf = svm.SVC()
6 clf.fit(X, y)  
7 joblib.dump(clf, "train_model.m")

2、讀取

1 clf = joblib.load("train_model.m")
2 clf.predit([0,0]) #此處test_X為特征集

 

二、TensorFlow模型保存與讀取(該方式tensorflow只能保存變量而不是保存整個網絡,所以在提取模型時,我們還需要重新第一網絡結構。) 
1、保存

 1 import tensorflow as tf  
 2 import numpy as np  
 3 
 4 W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')  
 5 b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')  
 6  
 7 init = tf.initialize_all_variables()  
 8 saver = tf.train.Saver()  
 9 with tf.Session() as sess:  
10          sess.run(init)  
11          save_path = saver.save(sess,"save/model.ckpt")  

2、加載

1 import tensorflow as tf  
2 import numpy as np  
3  
4 W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')  
5 b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')  
6  
7 saver = tf.train.Saver()  
8 with tf.Session() as sess:  
9       saver.restore(sess,"save/model.ckpt")  

 

三、TensorFlow模型保存與讀取(該方式tensorflow保存整個網絡) 
1、保存

 1 import tensorflow as tf
 2 
 3 # First, you design your mathematical operations
 4 # We are the default graph scope
 5 
 6 # Let's design a variable
 7 v1 = tf.Variable(1. , name="v1")
 8 v2 = tf.Variable(2. , name="v2")
 9 # Let's design an operation
10 a = tf.add(v1, v2)
11 
12 # Let's create a Saver object
13 # By default, the Saver handles every Variables related to the default graph
14 all_saver = tf.train.Saver() 
15 # But you can precise which vars you want to save under which name
16 v2_saver = tf.train.Saver({"v2": v2}) 
17 
18 # By default the Session handles the default graph and all its included variables
19 with tf.Session() as sess:
20   # Init v and v2   
21   sess.run(tf.global_variables_initializer())
22   # Now v1 holds the value 1.0 and v2 holds the value 2.0
23   # We can now save all those values
24   all_saver.save(sess, 'data.chkp')
25   # or saves only v2
26   v2_saver.save(sess, 'data-v2.chkp')
27 模型的權重是保存在 .chkp 文件中,模型的圖是保存在 .chkp.meta 文件中。

2、加載

 1 import tensorflow as tf
 2 
 3 # Let's laod a previous meta graph in the current graph in use: usually the default graph
 4 # This actions returns a Saver
 5 saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')
 6 
 7 # We can now access the default graph where all our metadata has been loaded
 8 graph = tf.get_default_graph()
 9 
10 # Finally we can retrieve tensors, operations, etc.
11 global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')
12 train_op = graph.get_operation_by_name('loss/train_op')
13 hyperparameters = tf.get_collection('hyperparameters')
14 
15 恢復權重
16 
17 請記住,在實際的環境中,真實的權重只能存在於一個會話中。也就是說,restore 這個操作必須在一個會話中啟動,然后將數據權重導入到圖中。理解恢復操作的最好方法是將它簡單的看做是一種數據初始化操作。
18 with tf.Session() as sess:
19     # To initialize values with saved data
20     saver.restore(sess, 'results/model.ckpt-1000-00000-of-00001')
21     print(sess.run(global_step_tensor)) # returns 1000

 

四、keras模型保存和加載

1 model.save('my_model.h5')  
2 model = load_model('my_model.h5') 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM