MxNet 模型轉Tensorflow pb模型


用mmdnn實現模型轉換

參考鏈接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

  1. 安裝mmdnn
    pip install mmdnn

     

  2. 准備好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50為例        https://github.com/deepinsight/insightface
  3. 用mmdnn運行命令行
    python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d resnet50 --inputShape 3,112,112 

     

     會生成resnet50.json(可視化文件) resnet50.npy(權重參數) resnet50.pb(網絡結構)三個文件。

  4. 用mmdnn運行命令行
    python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py 

     

     生成tf_resnet50.py文件,可以調用tf_resnet50.py中的KitModel函數加載npy權重參數重新生成原網絡框架。

  5. 打開tf_resnet.py文件,修改load_weights()中的代碼 (tensorflow=1.14.0報錯) 

     try:
            weights_dict = np.load(weight_file).item()
        except:
            weights_dict = np.load(weight_file, encoding='bytes').item()

    改為

     try:
            weights_dict = np.load(weight_file, allow_pickle=True).item()
    except:
            weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()

     

  6. 基於resnet50.npy和tf_resnet50.py文​​件,固化參數,生成PB文件:

    import tensorflow as tf
    import tf_resnet50 as tf_fun
    def netWork():
        model=tf_fun.KitModel("./resnet50.npy")
        return model
    def freeze_graph(output_graph):
        output_node_names = "output"
        data,fc1=netWork()
        fc1=tf.identity(fc1,name="output")
    
        graph = tf.get_default_graph()  # 獲得默認的圖
        input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
                sess=sess,
                input_graph_def=input_graph_def,  # 等於:sess.graph_def
                output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
    
            with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                f.write(output_graph_def.SerializeToString())  # 序列化輸出
    
    if __name__ == '__main__':
        freeze_graph("frozen_insightface_r50.pb")
        print("finish!")

     

  7. 采用tensorflow的post-train quantization離線量化方法(有一定的精度損失)轉換成tflite模型,從而完成端側的模型部署:
    import tensorflow as tf
    
    convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
                                                      input_shapes={"data":[1,112,112,3]})
    convert.post_training_quantize=True
    tflite_model=convert.convert()
    open("quantized_insightface_r50.tflite","wb").write(tflite_model)
    print("finish!")

     


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM