Tensorflow 保存模型 & 在java中調用


本節涉及:

  1. 保存TensorFlow 的模型供其他語言使用
  2. java中調用模型並進行預測計算

一、保存TensorFlow 的模型供其他語言使用

 

 如果用戶選擇“y” ,則執行下面的步驟:

  • 判斷程序執行目錄下是否有 export 目錄,如果有,調用 shutil 包中的 rmtress 函數將其刪除,以免沖突
  • builder = tf .saved_model . builder . SavedModelBuilder ("export")   ———— 用於生成保存神經網絡模型的對象builder,並指定保存位置為程序執行目錄下的 export 子目錄
  • builder.add_meta_graph_and_variables (sess,["tag"]) ———— 指定保存會話對象 sess 中的默認數據流圖和可變參數(即保存模型的主要內容),並起標記名 “tag”,這個標記名 在以后被其他語言調用時會被引用
  • builder.save() ———— 保存

 

完后,會在程序執行目錄下生成一個 export 子目錄,其中包含了需要傳遞給其他語言程序的神經網絡模型的相關文件。

在其他語言調用時,需要把這個文件夾 整個復制到需要使用的計算機上

二、java中調用模型並進行預測計算

 

 

 調用模型文件進行預測的示例:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;

public class TestTF {

    public static void main(String[] args) {  
        SavedModelBundle smb = SavedModelBundle.load("export", "tag");
        
        Session s = smb.session();
        
        float[][] matrix = {{1.0F, 2.0F, 3.0F, 4.0F}};
        System.out.println(Arrays.deepToString(matrix));  

        Tensor xFeed = Tensor.create(matrix);
        
        Tensor result = s.runner().feed("x", xFeed).fetch("y").run().get(0);
            
        FloatBuffer buf = FloatBuffer.allocate(2);
            
        result.writeTo(buf);
            
        System.out.println(result.toString());  

        System.out.println(buf.get(0));  
        System.out.println(buf.get(1));  
    }
}

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM