主要記錄以下輸入、輸出參數處理過程,其他初始化百度資料很多。
背景
項目中用到鑒黃識別,從Github上找到了別人訓練好的pb模型,項目地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model
但是項目中只提供了python代碼,首先對python不熟悉,並且發現tensorflow提供了對java預測模型的支持,並且項目使用的是java,所以想把tensorflow 集成到項目中,調用pb模型預測。
但通過tensorboard工具查看模型時發現輸入參數為string,雖然可以跑通,但到現在也不理解入參為什么設計成string類型.
pb文件參數(output_graph.pb)
在調用模型之前,需要先清楚模型輸入、輸出參數類型。
輸入名稱:DecodeJpeg/contents:0 類型: string,實際傳入圖片文件原始數據就可以
輸出名稱:final_result:0 類型: float
這個文件的輸入、輸出參數類型,通過CVSample項目庫中python調用代碼,找到輸入、輸出名稱
也可以先用python生成日志,通過tensorboard工具分析日志,拿到模型輸入、輸出參數
推薦參考示例(LabelImage):
tensorflow 官方有一個labelImg的java示例,如果第一次使用tensorflow java api,應該會對你有用: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
如果想運行這個示例,下載示例中提到的模型: https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
這個示例中,輸入、輸出參數和鑒黃識別模型參數不太一樣,所以也會有一些區別。
在這個示例中,對圖片進行了一些圖像預處理。
圖像是否需要預處理,需要看模型,有些模型需要,有些不需要(比如這個鑒黃模型)。
精簡代碼:
tensorflow: 1.15.0
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.15.0</version> </dependency>
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
public static void main(String[] args) throws IOException { try (Graph g = new Graph()) { //pb 模型文件 byte modelBytes[] = Files.readAllBytes(new File("/opt/work/java_work/tensorflow_demo/inception_model/output_graph.pb").toPath()); g.importGraphDef(modelBytes); try (Session s = new Session(g)) { //生成輸入參數,此處生成從 https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 中找到的方法 Tensor<String> tensor = (Tensor<String>) Tensor.create(Files.readAllBytes(Paths.get("/root/test.png"))); Tensor<Float> result = s.runner() //輸入參數 .feed("DecodeJpeg/contents:0", tensor) //輸出參數 .fetch("final_result:0") .run() .get(0) .expect(Float.class); //存儲結果容器, 輸出固定有5條數據,分別是每個分類(0:porn 1:neutral 2:hentai 3:drawings 4:sexy)的分數 float[][] values = new float[1][5]; result.copyTo(values); System.out.println(Arrays.toString(values[0])); //結果[0.027002065, 0.8941082, 0.02338332, 0.044249564, 0.011256761] //porn(色情): 0.027002065, neutral(正常): 0.8941082, hentai: 0.02338332, drawings: 0.044249564, sexy(性感): 0.011256761 } } }
前前后后為了生成輸入參數查了一周,網上資料是真的少,為了有相同問題的人可以快速解決,避免和我類似情況出現,所以此處記錄以下。