Tensorflow Lite tflite模型的生成與導入


假如想要在ARM板上用 tensorflow lite,那么意味着必須要把PC上的模型生成 tflite文件,然后在ARM上導入這個 tflite文件,通過解析這個文件來進行計算。
根據前面所說, tensorflow的所有計算都會在內部生成一個圖,包括變量的初始化,輸入定義等,那么即便不是經過訓練的神經網絡模型,只是簡單的三角函數計算,也可以生成一個 tflite模型用於在 tensorflow lite上導入。所以,這里我就只做了簡單的 sin()計算來跑一編這個流程。

生成tflite模型

這部分主要是調用TFLiteConverter函數,直接生成tflite文件,不再通過pb文件轉化。
先上代碼:

import numpy as np
import time
import math
import tensorflow as tf

SIZE = 1000
X = np.random.rand(SIZE, 1)
X = X*(math.pi/2.0)

start = time.time()
x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')
x2 = tf.placeholder(tf.float32, [SIZE, 1], name='x2-input')
y1 = tf.sin(x1)
y2 = tf.sin(x2)
y = y1*y2

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    converter = tf.lite.TFLiteConverter.from_session(sess, [x1, x2], [y])
    tflite_model = converter.convert()
    open("/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite", "wb").write(tflite_model)

end = time.time()
print("2nd ", str(end - start))
轉化函數
主要遇到的問題是 tensorflow的變化實在太快,這些個轉化函數一直在變。位置也一直在變,現在參考 官方文檔,是按上面代碼中調用,否則就會報找不到 lite之類的錯誤。我現在PC上的 tensorflow Python版本是1.13,所以 lite已經在 contrib外面了,如果是以前的版本,要按文檔中下面這樣調用。
 
TensorFlow Version Python API
1.12 tf.contrib.lite.TFLiteConverter
1.9-1.11 tf.contrib.lite.TocoConverter
1.7-1.8 tf.contrib.lite.toco_convert

 

 

 

 

 

輸入參數shape

本來在本文件中為了給定的輸入數據大小自由,x1,x2shape會寫成[None, 1],但是如果這樣寫,轉化成tflite模型后會默認為[1,1],並不能自由接收數據大小,所以在這里要指定大小SIZE

x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')

導入tflite模型

本來這部分應該是在ARM板子上做的,但是為了驗證tflite文件的可用性,我先在PC的Python上試驗。先上代碼:

import tensorflow as tf
import numpy as np
import math
import time

SIZE = 1000
X = np.random.rand(SIZE, 1, ).astype(np.float32)
X = X*(math.pi/2.0)

start = time.time()

interpreter = tf.lite.Interpreter(model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(input_details[0]['index'], X)
interpreter.set_tensor(input_details[1]['index'], X)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
end = time.time()
print("1st ", str(end - start))
首先根據 tflite文件生成解析器,然后用 allocate_tensors()分配內存。將輸入通過 set_tensor傳入,然后調用 invoke()來真正運行。最后得到輸出。
Python跑的時候可以很清楚的看到 input_details的數據結構。官方的例子是只傳入一個數據,所以只需要取 input_details[0],而我傳入了2個輸入,所以需要設置2個。同時可以看到 input_details的2個數據的名字都是我在之前設置的 x1-inputx2-input,這樣非常好理解。
輸入參數類型
這里有個坑是輸入參數的類型一定要注意。我在生成模型的時候定義的輸入參數類型是 tf.float32,而在導入的時候如果直接是 X = np.random.rand(SIZE, 1, )的話,會報錯:
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 3

這里把通過astype(np.float32)把輸入參數指定為float32就OK了。

  • 操作不支持的坑
    可以從前面的代碼里看到我寫了兩個sin(),其實一開始是一個sin()一個cos()的,但是好像默認的tflite模型不支持cos()操作,無法生成,所以我只好暫時先只寫sin(),后面再研究怎么把cos()加上。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM