tensorflow java 調用pb模型預測實例(CVSample 鑒黃檢測)


主要記錄以下輸入、輸出參數處理過程,其他初始化百度資料很多。

背景

項目中用到鑒黃識別,從Github上找到了別人訓練好的pb模型,項目地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model

但是項目中只提供了python代碼,首先對python不熟悉,並且發現tensorflow提供了對java預測模型的支持,並且項目使用的是java,所以想把tensorflow 集成到項目中,調用pb模型預測。

但通過tensorboard工具查看模型時發現輸入參數為string,雖然可以跑通,但到現在也不理解入參為什么設計成string類型.

 

pb文件參數(output_graph.pb)

在調用模型之前,需要先清楚模型輸入、輸出參數類型。

輸入名稱:DecodeJpeg/contents:0    類型: string,實際傳入圖片文件原始數據就可以
輸出名稱:final_result:0       類型: float

這個文件的輸入、輸出參數類型,通過CVSample項目庫中python調用代碼,找到輸入、輸出名稱

也可以先用python生成日志,通過tensorboard工具分析日志,拿到模型輸入、輸出參數

 

推薦參考示例(LabelImage):

tensorflow 官方有一個labelImg的java示例,如果第一次使用tensorflow java api,應該會對你有用: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

如果想運行這個示例,下載示例中提到的模型: https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

這個示例中,輸入、輸出參數和鑒黃識別模型參數不太一樣,所以也會有一些區別。

在這個示例中,對圖片進行了一些圖像預處理。 

圖像是否需要預處理,需要看模型,有些模型需要,有些不需要(比如這個鑒黃模型)。

 

精簡代碼:

tensorflow: 1.15.0

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.15.0</version>
        </dependency>

import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;

public static void main(String[] args) throws IOException { try (Graph g = new Graph()) { //pb 模型文件 byte modelBytes[] = Files.readAllBytes(new File("/opt/work/java_work/tensorflow_demo/inception_model/output_graph.pb").toPath()); g.importGraphDef(modelBytes); try (Session s = new Session(g)) { //生成輸入參數,此處生成從 https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 中找到的方法 Tensor<String> tensor = (Tensor<String>) Tensor.create(Files.readAllBytes(Paths.get("/root/test.png"))); Tensor<Float> result = s.runner() //輸入參數 .feed("DecodeJpeg/contents:0", tensor) //輸出參數 .fetch("final_result:0") .run() .get(0) .expect(Float.class); //存儲結果容器, 輸出固定有5條數據,分別是每個分類(0:porn 1:neutral 2:hentai 3:drawings 4:sexy)的分數 float[][] values = new float[1][5]; result.copyTo(values); System.out.println(Arrays.toString(values[0])); //結果[0.027002065, 0.8941082, 0.02338332, 0.044249564, 0.011256761] //porn(色情): 0.027002065, neutral(正常): 0.8941082, hentai: 0.02338332, drawings: 0.044249564, sexy(性感): 0.011256761 } } }

 

 

前前后后為了生成輸入參數查了一周,網上資料是真的少,為了有相同問題的人可以快速解決,避免和我類似情況出現,所以此處記錄以下。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM