【6】TensorFlow光速入門-python模型轉換為tfjs模型並使用


本文地址:https://www.cnblogs.com/tujia/p/13862365.html

 

系列文章:

【0】TensorFlow光速入門-序

【1】TensorFlow光速入門-tensorflow開發基本流程

【2】TensorFlow光速入門-數據預處理(得到數據集)

【3】TensorFlow光速入門-訓練及評估

【4】TensorFlow光速入門-保存模型及加載模型並使用

【5】TensorFlow光速入門-圖片分類完整代碼

【6】TensorFlow光速入門-python模型轉換為tfjs模型並使用

【7】TensorFlow光速入門-總結

 

一、模型轉換

python模型轉換tfjs模型,需要用到先安裝 tensorflowjs_converter 工具

 pip install tensorflowjs

安裝成功后,可以用python腳本或shell命令轉換,下面是shell的例子:

tensorflowjs_converter /tf/saved_model/wnw /tf/saved_model_js/wnw

注:記得提前創建好 saved_model_js 目錄。轉換成功成,會得到 model.json 及 n個 .bin 文件,例如:group1-shard1of2.bin、group1-shard2of2.bin等等

 

轉換命令的詳細參數,請看:

tensorflowjs_converter --help

 

二、在瀏覽器中使用

先准備好模型文件及dict.txt 文件(注:dict.txt 需要自己創建,內容為圖片分類,一行一個分類)

     

基礎html元素

<input type="file" class="custom-file-input" id="file" accept="image/*" capture="camera">

<input class="form-control" id="result" readonly="readonly">

<img src="" id="pic">

引入tfjs

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.6.0/dist/tf.min.js"></script>

加載模型

const MODEL_URL = '/static/models/wnw/model.json';

let model = null;
tf.loadGraphModel(MODEL_URL).then((value)=>{
    model = value;
}, (error)=>{
    console.log(error);
});

偵聽選擇圖片及圖片預覽

let image = document.getElementById('pic');

// 圖片預覽
document.getElementById('file').addEventListener('change', (ev)=>{
    let reader = new FileReader();
    reader.addEventListener('load', (e)=>{
        image.src = e.target.result;
    });
    reader.readAsDataURL(ev.target.files[0]);
});

圖片數據轉換及預測

// 圖片分類
const CLASSIFY = ['非表', '表'];

// 圖片處理及評估
image.addEventListener('load', ()=>{
    // 圖片轉換成灰度張量數據
    let image_tensor = tf.browser.fromPixels(image, 1);
    // 三維張量轉四維張量
    image_tensor = tf.expandDims(image_tensor);
    image_tensor = tf.cast(image_tensor, 'float32');
    // console.log(image_tensor.shape);
    // 圖片縮放,轉換為模型需要的大小
    image_tensor = tf.image.resizeBilinear(image_tensor, [100, 100]);
    // console.log(image_tensor.shape);
    let predictions = model.predict(image_tensor);
    let label = tf.argMax(predictions, 1).dataSync()[0];
    result.value = CLASSIFY[label];
});

注:像【4】TensorFlow光速入門-保存模型及加載模型並使用 說的那樣,加載模型,然后准備一個和訓練集一樣格式的數據(數據格式轉換、縮放),然后預測就可以了

 

重點:

tf.browser.fromPixels    base64格式轉tensor3D格式

tf.expandDims         tensor3D格式車轉tensor4D格式

tf.cast            數值轉換,上面例子是int32轉float32

tf.image.resizeBilinear            圖片縮放

model.predict        模型預測

tf.argMax(predictions, 1).dataSync()[0]     取預測結果的最大值的 key(即分類label)

 


 

其他:

官方關於tfjs的使用示例並不完善,甚至是錯。各種跳轉,又是 MobileNet 又是 ml5 的,其實都不需要,直接用 tf.min.js 就可以了。mobilenet 和 ml5 的用法以后再研究

下面是相當混亂的一些相關文檔:

https://tensorflow.google.cn/js/tutorials/conversion/import_keras

https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model

https://github.com/tensorflow/tfjs-converter/blob/master/tfjs-converter/README.md

https://github.com/tensorflow/tfjs-converter/tree/master/tfjs-converter/demo/mobilenet

https://learn.ml5js.org/#/tutorials/hello-ml5?id=demo

 

本文鏈接:https://www.cnblogs.com/tujia/p/13862365.html


 完。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM