TensorFlow Lite for Android示例


一、TensorFlow  Lite

TensorFlow Lite 是用於移動設備和嵌入式設備的輕量級解決方案。TensorFlow Lite 支持 Android、iOS 甚至樹莓派等多種平台。

二、tflite格式

TensorFlow 生成的模型是無法直接給移動端使用的,需要離線轉換成.tflite文件格式。

tflite 存儲格式是 flatbuffers。

FlatBuffers 是由Google開源的一個免費軟件庫,用於實現序列化格式。它類似於Protocol Buffers、Thrift、Apache Avro。

因此,如果要給移動端使用的話,必須把 TensorFlow 訓練好的 protobuf 模型文件轉換成 FlatBuffers 格式。官方提供了 toco 來實現模型格式的轉換。

三、API

TensorFlow Lite 提供了 C ++ 和 Java 兩種類型的 API。無論哪種 API 都需要加載模型和運行模型。

而 TensorFlow Lite 的 Java API 使用了 Interpreter 類(解釋器)來完成加載模型和運行模型的任務。后面的例子會看到如何使用 Interpreter。

四、TensorFlow Lite實現手寫數字識別

下面的 demo 中已經包含了 mnist.tflite 模型文件。(如果沒有的話,需要自己訓練保存成pb文件,再轉換成tflite 格式)
對於一個識別類,首先需要初始化 TensorFlow Lite 解釋器,以及輸入、輸出。
    // The tensorflow lite file
    private lateinit var tflite: Interpreter

    // Input byte buffer
    private lateinit var inputBuffer: ByteBuffer

    // Output array [batch_size, 10]
    private lateinit var mnistOutput: Array<FloatArray>

    init {

        try {
            tflite = Interpreter(loadModelFile(activity))

            inputBuffer = ByteBuffer.allocateDirect(
                    BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)
            inputBuffer.order(ByteOrder.nativeOrder())
            mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) }
            Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.")
        } catch (e: IOException) {
            Log.e(TAG, "IOException loading the tflite file failed.")
        }

    }

從 asserts 文件中加載 mnist.tflite 模型:

    /**
     * Load the model file from the assets folder
     */
    @Throws(IOException::class)
    private fun loadModelFile(activity: Activity): MappedByteBuffer {

        val fileDescriptor = activity.assets.openFd(MODEL_PATH)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

真正識別手寫數字是在 classify() 方法:

val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))

classify() 方法包含了預處理用於初始化 inputBuffer、運行 mnist 模型、識別出數字。

    /**
     * Classifies the number with the mnist model.
     *
     * @param bitmap
     * @return the identified number
     */
    fun classify(bitmap: Bitmap): Int {

        if (tflite == null) {
            Log.e(TAG, "Image classifier has not been initialized; Skipped.")
        }

        preProcess(bitmap)
        runModel()
        return postProcess()
    }

    /**
     * Converts it into the Byte Buffer to feed into the model
     *
     * @param bitmap
     */
    private fun preProcess(bitmap: Bitmap?) {

        if (bitmap == null || inputBuffer == null) {
            return
        }

        // Reset the image data
        inputBuffer.rewind()

        val width = bitmap.width
        val height = bitmap.height

        // The bitmap shape should be 28 x 28
        val pixels = IntArray(width * height)
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height)

        for (i in pixels.indices) {
            // Set 0 for white and 255 for black pixels
            val pixel = pixels[i]
            // The color of the input is black so the blue channel will be 0xFF.
            val channel = pixel and 0xff
            inputBuffer.putFloat((0xff - channel).toFloat())
        }
    }

    /**
     * Run the TFLite model
     */
    private fun runModel() = tflite.run(inputBuffer, mnistOutput)

    /**
     * Go through the output and find the number that was identified.
     *
     * @return the number that was identified (returns -1 if one wasn't found)
     */
    private fun postProcess(): Int {

        for (i in 0 until mnistOutput[0].size) {
            val value = mnistOutput[0][i]
            if (value == 1f) {
                return i
            }
        }

        return -1
    }

對於 Android 有一個地方需要注意,必須在 app 模塊的 build.gradle 中添加如下的語句,否則無法加載模型。

android {
    ......
    aaptOptions {
        noCompress "tflite"
    }
}

效果:

 

 五、總結

本文 demo 的 github 地址:https://github.com/fengzhizi715/TFLite-MnistDemo

當然,也可以跑一下官方的例子:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/examples/android/app

雖然准確度都不咋地。。。

更多有趣的TensorFlow Lite示例:https://www.tensorflow.org/lite/examples/

 

 

參考鏈接:https://www.jianshu.com/p/e96f80c80e43

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM