假如想要在ARM板上用
根據前面所說,
tensorflow lite,那么意味着必須要把PC上的模型生成
tflite文件,然后在ARM上導入這個
tflite文件,通過解析這個文件來進行計算。
根據前面所說,
tensorflow的所有計算都會在內部生成一個圖,包括變量的初始化,輸入定義等,那么即便不是經過訓練的神經網絡模型,只是簡單的三角函數計算,也可以生成一個
tflite模型用於在
tensorflow lite上導入。所以,這里我就只做了簡單的
sin()計算來跑一編這個流程。
生成tflite模型
這部分主要是調用TFLiteConverter函數,直接生成tflite文件,不再通過pb文件轉化。
先上代碼:
import numpy as np import time import math import tensorflow as tf SIZE = 1000 X = np.random.rand(SIZE, 1) X = X*(math.pi/2.0) start = time.time() x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input') x2 = tf.placeholder(tf.float32, [SIZE, 1], name='x2-input') y1 = tf.sin(x1) y2 = tf.sin(x2) y = y1*y2 with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) converter = tf.lite.TFLiteConverter.from_session(sess, [x1, x2], [y]) tflite_model = converter.convert() open("/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite", "wb").write(tflite_model) end = time.time() print("2nd ", str(end - start))
轉化函數
主要遇到的問題是
主要遇到的問題是
tensorflow的變化實在太快,這些個轉化函數一直在變。位置也一直在變,現在參考
官方文檔,是按上面代碼中調用,否則就會報找不到
lite之類的錯誤。我現在PC上的
tensorflow
Python版本是1.13,所以
lite已經在
contrib外面了,如果是以前的版本,要按文檔中下面這樣調用。
| TensorFlow Version | Python API |
| 1.12 | tf.contrib.lite.TFLiteConverter |
| 1.9-1.11 | tf.contrib.lite.TocoConverter |
| 1.7-1.8 | tf.contrib.lite.toco_convert |
輸入參數shape
本來在本文件中為了給定的輸入數據大小自由,x1,x2的shape會寫成[None, 1],但是如果這樣寫,轉化成tflite模型后會默認為[1,1],並不能自由接收數據大小,所以在這里要指定大小SIZE:
x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')
導入tflite模型
本來這部分應該是在ARM板子上做的,但是為了驗證tflite文件的可用性,我先在PC的Python上試驗。先上代碼:
import tensorflow as tf import numpy as np import math import time SIZE = 1000 X = np.random.rand(SIZE, 1, ).astype(np.float32) X = X*(math.pi/2.0) start = time.time() interpreter = tf.lite.Interpreter(model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite") interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() interpreter.set_tensor(input_details[0]['index'], X) interpreter.set_tensor(input_details[1]['index'], X) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) end = time.time() print("1st ", str(end - start))
首先根據
用
輸入參數類型
tflite文件生成解析器,然后用
allocate_tensors()分配內存。將輸入通過
set_tensor傳入,然后調用
invoke()來真正運行。最后得到輸出。
用
Python跑的時候可以很清楚的看到
input_details的數據結構。官方的例子是只傳入一個數據,所以只需要取
input_details[0],而我傳入了2個輸入,所以需要設置2個。同時可以看到
input_details的2個數據的名字都是我在之前設置的
x1-input和
x2-input,這樣非常好理解。
這里有個坑是輸入參數的類型一定要注意。我在生成模型的時候定義的輸入參數類型是
tf.float32,而在導入的時候如果直接是
X = np.random.rand(SIZE, 1, )的話,會報錯:
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 3
這里把通過astype(np.float32)把輸入參數指定為float32就OK了。
- 操作不支持的坑
可以從前面的代碼里看到我寫了兩個sin(),其實一開始是一個sin()一個cos()的,但是好像默認的tflite模型不支持cos()操作,無法生成,所以我只好暫時先只寫sin(),后面再研究怎么把cos()加上。
