本節涉及:
- 保存TensorFlow 的模型供其他語言使用
- java中調用模型並進行預測計算
一、保存TensorFlow 的模型供其他語言使用
如果用戶選擇“y” ,則執行下面的步驟:
- 判斷程序執行目錄下是否有 export 目錄,如果有,調用 shutil 包中的 rmtress 函數將其刪除,以免沖突
- builder = tf .saved_model . builder . SavedModelBuilder ("export") ———— 用於生成保存神經網絡模型的對象builder,並指定保存位置為程序執行目錄下的 export 子目錄
- builder.add_meta_graph_and_variables (sess,["tag"]) ———— 指定保存會話對象 sess 中的默認數據流圖和可變參數(即保存模型的主要內容),並起標記名 “tag”,這個標記名 在以后被其他語言調用時會被引用
- builder.save() ———— 保存
完后,會在程序執行目錄下生成一個 export 子目錄,其中包含了需要傳遞給其他語言程序的神經網絡模型的相關文件。
在其他語言調用時,需要把這個文件夾 整個復制到需要使用的計算機上
二、java中調用模型並進行預測計算
調用模型文件進行預測的示例:
import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; import org.tensorflow.SavedModelBundle; import java.nio.FloatBuffer; import java.util.Arrays; public class TestTF { public static void main(String[] args) { SavedModelBundle smb = SavedModelBundle.load("export", "tag"); Session s = smb.session(); float[][] matrix = {{1.0F, 2.0F, 3.0F, 4.0F}}; System.out.println(Arrays.deepToString(matrix)); Tensor xFeed = Tensor.create(matrix); Tensor result = s.runner().feed("x", xFeed).fetch("y").run().get(0); FloatBuffer buf = FloatBuffer.allocate(2); result.writeTo(buf); System.out.println(result.toString()); System.out.println(buf.get(0)); System.out.println(buf.get(1)); } }