tensorflow 模型權重導出


tensorflow在保存權重模型時多使用tf.train.Saver().save 函數進行權重保存,保存的ckpt文件無法直接打開,不利於將模型權重導入到其他框架使用(如Caffe、Keras等)。

好在tensorflow提供了相關函數 tf.train.NewCheckpointReader 可以對ckpt文件進行權重查看,因此可以通過該函數進行數據導出。

 1 import tensorflow as tf
 2 import h5py
 3 
 4 cpktLogFileName = r'./checkpoint/checkpoint'  #cpkt 文件路徑
 5 with open(cpktLogFileName, 'r') as f:
 6     #權重節點往往會保留多個epoch的數據,此處獲取最后的權重數據      
 7     cpktFileName = f.readline().split('"')[1]     
 8 
 9 h5FileName = r'./model/net_classification.h5'
10 
11 reader = tf.train.NewCheckpointReader(cpktFileName)
12 f = h5py.File(h5FileName, 'w')
13 t_g = None
14 for key in sorted(reader.get_variable_to_shape_map()):
15    # 權重名稱需根據自己網絡名稱自行修改
16    if key.endswith('w') or key.endswith('biases'):
17         keySplits = key.split(r'/')
18         keyDict = keySplits[1] + '/' + keySplits[1] + '/' + keySplits[2]
19         f[keyDict] = reader.get_tensor(key)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM