Tensorflow中tflite權重參數提取與推理過程示意


1、引言

最近一段時間在對卷積神經網絡進行量化的過程中,閱讀了部分論文,其中對於谷歌在CVPR2018上發表的論文“Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference”印象深刻,在工程的應用上由於和tensorflow官方的代碼有密切相關,在實際的工程實踐上擁有一個良好的框架平台。首先筆者對谷歌發表的論文進行了詳細的解讀,同時對於其中推理部分原理進行了分析,了解到谷歌的量化方式存儲方式為輸入特征圖input和卷積核filter均為8bit無符號整形存儲方式,對於偏置bias為32bit有符號整形存儲。典型的應用是在MobileNet網絡上的量化,並且官方還提供了相應的模型文件,比如tflite。這種存儲方式其實在大部分的推理中已經可以實現足夠的精度優勢了,但是這里面存在着一些問題,正常來講tflite的運行是基於現有Tensorflow框架的,官方提供的示意demo包括了基於C++,Java,IOS等平台,依然沒能脫離已有的平台限制。

但是實際的應用中,比如在嵌入式微處理器或者微控制器,以及筆者一直研究的FPGA平台加速器等領域,不可能搭建一個Tensorflow框架來運行tflite模型,這樣在一定程度上限制了算法的應用領域。所以例如像意法半導體在STM32 Cube里面提供的Cube AI工具包,以及Xilinx提供的DPU完整開發框架,都在某些方面擴展了人工智能算法的適用范圍。那么如果從學術、教育、研究的角度去考慮這些問題,我們希望能夠基於現有的優秀成果,去理解內部的運行機理,從而提高我們對算法的深層次理解。

然而,目前從網上查閱到的資料里面,對於tflite權重數據提取方面的內容都比較零碎,也沒有一個比較系統的介紹,因此筆者結合前段時間的一些開發經驗,對tflite的數據提取和推理過程做一個簡單的演示與總結,同時也希望能夠起到一個拋磚引玉的作用。

本篇隨筆的主要內容如下:首先我們以MobileNetV1的tflite模型文件為例,介紹基於Tensorflow框架如何調用與運行;然后,我們介紹如何提取權重的參數,這里面包含了卷積核filter、偏置bias、量化系數(需要結合論文來說明);最后,我們再借助提取到的數據模擬MobileNet的前向推理,實現對輸入圖片的分類。

相關源碼與文件將在文末提供。

2、基於Tensorflow/Tf_nightly框架的tflite模型文件調用

先介紹下我們使用的模型文件,以MobileNetV1為例,Tensorflow官方提供的量化后模型文件有很多,詳細的信息可以參考https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md,里面提供了不同尺度大小的性能和模型文件下載。筆者選用的是里面最簡單、規模最小的網絡MobileNet_v1_0.25_128_quant,這樣便於數據的加載和分析。

接着我們可以在Python里面直接調用tflite文件,對輸入圖片進行一次分類,對應的運行環境為Ubuntu18.04,Python中所用到的庫版本opencv==3.4.2.17,tensorflow==1.12.0 或者使用tf_nightly替代tensorflow。具體代碼如下:

 1 import time
 2 import cv2
 3 import numpy as np
 4 import tensorflow as tf
 5 
 6 model_path = "mobilenet_v1_0.25_128_quant.tflite"
 7 inter = tf.contrib.lite.Interpreter(model_path=model_path)
 8 inter.allocate_tensors()
 9 input_details = inter.get_input_details()
10 output_details = inter.get_output_details()
11 
12 img = cv2.imread('test2.png')
13 img = cv2.resize(img, (128, 128))
14 
15 img_exp = np.expand_dims(img, axis=0)
16 print(img_exp.shape)
17 
18 inter.set_tensor(input_details[0]['index'], img_exp)
19 
20 time_start =time.time()
21 inter.invoke()
22 time_end = time.time()
23 print(time_end-time_start)
24 
25 output_data = inter.get_tensor(output_details[0]['index'])
26 print(output_data.shape)
27 result = np.squeeze(output_data)
28 print('np.argmax(result):',np.argmax(result))

其中輸入圖片是test2.png,此處也可以用其他的圖片來代替,最終輸出的是圖片分類所需要的時間和分類結果。

此處筆者有個不明白的地方,通過查閱網絡資料,發現tflite是Tensorflow面向移動設備應用的數據格式,官方給出的應用示例包括Python、Java、C++等,也就是面向樹莓派(Linux)、安卓、蘋果等平台,貌似並不支持tflite在GPU上運行。截至目前,筆者使用的電腦端GPU和JetsonNano的嵌入式GPU均無法運行,只能在CPU上執行。

3、tflite權重參數提取

接下來就是如何提取tflite中經過訓練以后的參數了,在提取權重參數前需要了解Tensorflow在量化方面的相關理論知識,這里參考《Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference》這篇論文,其中對於8bit量化感知訓練給出了詳細的理論推導是具體實現過程。筆者在此處只對其前向推理部分做介紹,對於訓練部分沒做深入研究。

3.1 量化理論介紹(不關心可直接跳過) 

首先對量化的理論做一個簡單的介紹,假設輸入數據為A,權重數據為W,輸出數據為O,偏置為B,以上數據均為浮點型。那么根據公式可得:

O=A*W+B                                          (1)

上述公式可以分別對其進行量化操作,若采用定點量化方法,量化前與量化后數據如下所示:

A=(input-Z1)*S1

W=(weight-Z2)*S2                            (2)

O=(output-Z3)*S3

其中Z1、Z2、Z3分別為輸入、權重、輸出的零點偏移;S1、S2、S3分別為輸入、權重、輸出的尺度系數(縮放系數);input、weight、output分別為量化后的輸入、權重、輸出的結果,一般為8bit量化;

上述量化當然在一定程度上存在損失,但是隨着量化感知訓練可以將網絡最終的精度下降控制在很小的范圍內,甚至可以做到沒有精度損失。

那么將公式(2)代入到公式(1)中,並且對於公式(1)的B進行32bit量化,量化后結果表示為bias,縮放系數為S1*S2/S3,那么公式(1)可以表示為:

output = Z3+(S1*S2/S3)[(input-Z1)x(weight-Z2)+bias]                     (3)

而在Tensorflow的tflite網絡壓縮中,通常采用Relu6作為激活函數,那么input通常用無符號8bit,對應的取值范圍為0~255;weight通常用有符號8bit,對應的取值范圍為-128~127;同理output通常采用無符號8bit,對應取值范圍為0~255。

上述情況存在一種特例,也就是第一層網絡,輸入的為原始RGB圖像,需要轉換為有符號型數據,而其他網絡在Relu6激活函數的作用下,均可使用無符號型數據。因此,只有第一層網絡中Z1=128,而其他剩余網絡中Z1=0。對於每一層網絡的輸出,在Relu6激活函數的作用下,Z3=0。因此公式(3)可簡化為如下表達式:

output = (S1*S2/S3)[(input)x(weight-Z2)+bias]                                  (4)

如果對上述的尺度因子做統一表示,則公式(4)可進一步簡化為:

output = scale*[(input)x(weight-Z2)+bias] 

scale=S1*S2/S3                                                                                 (5)

而在文章《Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference》中,對於公式(5)中的scale進行分析,完成單層網絡的卷積操作后,在層間數據類型的轉換中,依然存在浮點計算操作,因此在文章中進一步采用了32bit定點加移位的表示方法,這樣對於不帶浮點處理單元的微處理,或對於計算資源敏感的FPGA單元,可以進一步提升計算性能。該方法筆者已在STM32和FPGA上進行了驗證,將在后續的博客中單獨對此進行介紹,同樣完成代碼的開源,和不同平台的性能測試結果對比。

3.2 權重數據提取

權重參數的提取需要借助工具netron,在此處非常感謝原作者能夠提供這么好的工具。安裝方法可以基於Python,也可直接在Windows端安裝。筆者采用Ubuntu下Python的pip安裝方法。在安裝完成以后進入命令行,輸入netron -b,此時不要退出命令行,瀏覽器將會自動打開一個網頁,如果網頁沒有打開,可以自行打開瀏覽器,輸入localhost:8080,進入網頁界面,並加載mobilenet_v1_0.25_128_quant.tflite文件。

單擊其中的逐個網絡層,截圖如下:

在圖中可以直接看到量化后的權重,同時也可以在網頁中直接將權重數據保存到本地。

當然上述方法一種比較簡單的做法,而筆者為了能夠做理論性驗證,需要復現網絡的模擬計算過程,因此需要做進一步的深入研究。

在上述截圖中需要注意的是每一層網絡中input和output對應的location編號,需要記下來,后續用得着。

具體python程序如下:

  1 # -*- coding: utf-8 -*-
  2 # @Time    : 2020.12.13
  3 # @Author  : wuruidong
  4 # @Email   : wuruidong@hotmail.com
  5 # @FileName: mobilenet_tf.py
  6 # @Software: python
  7 # @Cnblogs : https://www.cnblogs.com/ruidongwu
  8 
  9 import cv2
 10 import numpy as np
 11 import tensorflow as tf
 12 
 13 '''
 14 Library Version:
 15 python-opencv==3.4.2.17
 16 tensorflow==1.12.0 or tf_nightly
 17 
 18 '''
 19 
 20 '''
 21 Location number is obtained by netron (https://github.com/lutzroeder/netron).
 22 Thanks the authors for providing such a wonderful tool.
 23 # stage 1 (install): pip3 install netron
 24 # stage 2 (start with brower): netron -b
 25 # stage 3 (enter local ip): http://localhost:8080/
 26 # stage 4 (open tflite file): mobilenet_v1_0.25_128_quant.tflite
 27 # stage 5 (record location number)
 28 '''
 29 input_location =    np.array([88, 7,  33, 37, 39, 43, 45, 49, 51, 55, 57, 61, 63, 67, 69, 73, 75, 79, 81, 85, 9,  13, 15, 19, 21, 25, 27, 0], dtype=np.int)
 30 weight_location =   np.array([8,  35, 38, 41, 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 86, 11, 14, 17, 20, 23, 26, 29, 32, 3], dtype=np.int)
 31 bias_location =     np.array([6,  34, 36, 40, 42, 46, 48, 52, 54, 58, 60, 64, 66, 70, 72, 76, 78, 82, 84, 10, 12, 16, 18, 22, 24, 28, 30, 2], dtype=np.int)
 32 output_location =   np.array([7,  33, 37, 39, 43, 45, 49, 51, 55, 57, 61, 63, 67, 69, 73, 75, 79, 81, 85, 9,  13, 15, 19, 21, 25, 27, 31, 1], dtype=np.int)
 33 
 34 '''
 35 load tflite model from local file.
 36 '''
 37 def load_tflite(model_path=''):
 38     inter = tf.contrib.lite.Interpreter(model_path=model_path)
 39     #inter = tf.lite.Interpreter(model_path=model_path) # pip install tf-nightly
 40     inter.allocate_tensors()
 41     return inter
 42 
 43 '''
 44 load image with img_file name
 45 '''
 46 def load_img(img_file=''):
 47     img = cv2.imread(img_file)
 48     img = cv2.resize(img, (128, 128))
 49     img = np.expand_dims(img, axis=0)
 50     return img
 51 
 52 '''
 53 This function is network inference with tensorflow library.
 54 But it is a black box for education,
 55 and I want to analysis the principle of quantization with inference.
 56 If the filter/weight/bias/quantization could be exported in cunstom format with tables,
 57 so we can deploy user network or basic network on other platforms,
 58 not only Android/IOS/Raspberry,
 59 but also stm32/FPGA and so on.
 60 '''
 61 def tflite_inference(model, img):
 62     # get input node information
 63     input_details = model.get_input_details()
 64     # get output node information
 65     output_details = model.get_output_details()
 66     # set input data
 67     model.set_tensor(input_details[0]['index'], img)
 68     # start inference
 69     model.invoke()
 70     # get output data
 71     output_data = model.get_tensor(output_details[0]['index'])
 72     return output_data
 73 
 74 '''
 75 This function refers the paper of "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference".
 76 Thanks for the contribution of tensorflow.
 77 
 78 In this function, I will implement the quantitative inference.
 79 According to the network structure of mobilenet_v1,
 80 this function supports depthwise/pointwise/standard convolution,
 81 but the stride is only support 1 and 2.
 82 
 83 The main principle is the following equation(1) in the reference paper:
 84 output = Z3+(S1*S2/S3)[(input-Z1)x(weight-Z2)+bias]     (1)
 85 
 86 "(input-Z1)x(weight-Z2)" is the operation of depthwise/pointwise/standard convolution,
 87 Zi is the value of zero offset.
 88 Generally, Z3=0, Z1=128(layer0) Z1=0(other layers), Z2>0.
 89 If the activation function is relu6, S1=S3.
 90 so the equation(1) can be written as:
 91 output = (S1*S2/S3)[(input)x(weight-Z2)+bias]           (2)
 92 
 93 If scale=(S1*S2/S3) or scale=S2, and the equation(1) can be simplified as:
 94 output = scale*[(input)x(weight-Z2)+bias]               (3)
 95 
 96 In equation(3), the data type of scale is float32, input is uin8(other layers) or int8(layer0), weight is uint8,
 97 (weight-Z2) is int16 or int32.
 98 '''
 99 def my_conv(model, input, layer_index, layer_type='depthwise', strides=1):
100     input_index = input_location[layer_index]
101     weight_index = weight_location[layer_index]
102     bias_index = bias_location[layer_index]
103     output_index = output_location[layer_index]
104 
105     # input_quant[0]=>S1, input_quant[1]=>Z1
106     input_quant = model._interpreter.TensorQuantization(int(input_index))
107     # img_tensor = input-Z1
108     img_tensor = input - tf.constant(input_quant[1], dtype=tf.float32)
109 
110     # weight_quant[0]=>S2, weight_quant[1]=>Z2
111     weight_quant = model._interpreter.TensorQuantization(int(weight_index))
112     t_w = model.get_tensor(int(weight_index))
113     t_w = np.transpose(t_w, (1, 2, 3, 0))
114     weight_tensor = tf.convert_to_tensor(t_w)
115     weight_tensor = tf.cast(weight_tensor, dtype=tf.float32)
116     # weight_tensor = weight-Z2
117     weight_tensor = weight_tensor - tf.constant(weight_quant[1], dtype=tf.float32)
118     # bias_tensor = bias
119     bias_tensor = tf.convert_to_tensor(model.get_tensor(int(bias_index)), dtype=tf.float32)
120     # output_quant[0]=>S3(S3=0), output_quant[1]=>Z3
121     output_quant = model._interpreter.TensorQuantization(int(output_index))
122     # scale=(S1*S2/S3) Note: If the activation function is relu6, then scale=S2.
123     scale = input_quant[0] * weight_quant[0] / output_quant[0]
124 
125     if layer_type=='depthwise':
126         conv_res = tf.nn.depthwise_conv2d(img_tensor, weight_tensor, strides=[1, strides, strides, 1], padding='SAME')
127     elif layer_type=='pointwise':
128         conv_res = tf.nn.conv2d(img_tensor, weight_tensor, strides=[1, 1, 1, 1], padding='SAME')
129     elif layer_type=='standard':
130         conv_res = tf.nn.conv2d(img_tensor, weight_tensor, strides=[1, strides, strides, 1], padding='SAME')
131     else:
132         print('layer_type = depthwise? pointwise? standard?')
133     conv_bias = tf.nn.bias_add(conv_res, bias_tensor)
134     conv_scale = conv_bias * tf.constant(scale, dtype=tf.float32)
135 
136     return tf.clip_by_value(tf.round(conv_scale), 0, 255)
137 
138 
139 '''
140 Classifier of MobileNet
141 '''
142 def my_fc(model, input, layer_index):
143     input_index = input_location[layer_index]
144     weight_index = weight_location[layer_index]
145     bias_index = bias_location[layer_index]
146     output_index = output_location[layer_index]
147 
148     weight_quant = model._interpreter.TensorQuantization(int(weight_index))
149     t_w = model.get_tensor(int(weight_index))
150     t_w = np.transpose(t_w, (1, 2, 3, 0))
151     weight_tensor = tf.convert_to_tensor(t_w)
152     weight_tensor = tf.cast(weight_tensor, dtype=tf.float32)
153     weight_tensor = weight_tensor - tf.constant(weight_quant[1], dtype=tf.float32)
154 
155     return tf.matmul(input, weight_tensor)
156 
157 
158 model = load_tflite("mobilenet_v1_0.25_128_quant.tflite")
159 img = load_img('test2.png')
160 
161 print('***********************TFLite inference**************************')
162 tf_res = tflite_inference(model, img)
163 tf_res = np.squeeze(tf_res)
164 print('TFLite result is', np.argmax(tf_res))
165 
166 print('**********Custom inference for principle verification************')
167 
168 layer_index=0
169 
170 img_tensor = tf.convert_to_tensor(img)
171 img_tensor = tf.cast(img_tensor, dtype=tf.float32)
172 conv0 = my_conv(model, img_tensor, layer_index, layer_type='standard', strides=2)
173 layer_index = layer_index+1
174 
175 conv1 = my_conv(model, conv0, layer_index, layer_type='depthwise', strides=1)
176 layer_index = layer_index+1
177 conv2 = my_conv(model, conv1, layer_index, layer_type='pointwise', strides=1)
178 layer_index = layer_index+1
179 
180 conv3 = my_conv(model, conv2, layer_index, layer_type='depthwise', strides=2)
181 layer_index = layer_index+1
182 conv4 = my_conv(model, conv3, layer_index, layer_type='pointwise', strides=1)
183 layer_index = layer_index+1
184 
185 conv5 = my_conv(model, conv4, layer_index, layer_type='depthwise', strides=1)
186 layer_index = layer_index+1
187 conv6 = my_conv(model, conv5, layer_index, layer_type='pointwise', strides=1)
188 layer_index = layer_index+1
189 
190 conv7 = my_conv(model, conv6, layer_index, layer_type='depthwise', strides=2)
191 layer_index = layer_index+1
192 conv8 = my_conv(model, conv7, layer_index, layer_type='pointwise', strides=1)
193 layer_index = layer_index+1
194 
195 conv9 = my_conv(model, conv8, layer_index, layer_type='depthwise', strides=1)
196 layer_index = layer_index+1
197 conv10 = my_conv(model, conv9, layer_index, layer_type='pointwise', strides=1)
198 layer_index = layer_index+1
199 
200 conv11 = my_conv(model, conv10, layer_index, layer_type='depthwise', strides=2)
201 layer_index = layer_index+1
202 conv12 = my_conv(model, conv11, layer_index, layer_type='pointwise', strides=1)
203 layer_index = layer_index+1
204 
205 conv13 = my_conv(model, conv12, layer_index, layer_type='depthwise', strides=1)
206 layer_index = layer_index+1
207 conv14 = my_conv(model, conv13, layer_index, layer_type='pointwise', strides=1)
208 layer_index = layer_index+1
209 
210 conv15 = my_conv(model, conv14, layer_index, layer_type='depthwise', strides=1)
211 layer_index = layer_index+1
212 conv16 = my_conv(model, conv15, layer_index, layer_type='pointwise', strides=1)
213 layer_index = layer_index+1
214 
215 conv17 = my_conv(model, conv16, layer_index, layer_type='depthwise', strides=1)
216 layer_index = layer_index+1
217 conv18 = my_conv(model, conv17, layer_index, layer_type='pointwise', strides=1)
218 layer_index = layer_index+1
219 
220 conv19 = my_conv(model, conv18, layer_index, layer_type='depthwise', strides=1)
221 layer_index = layer_index+1
222 conv20 = my_conv(model, conv19, layer_index, layer_type='pointwise', strides=1)
223 layer_index = layer_index+1
224 
225 conv21 = my_conv(model, conv20, layer_index, layer_type='depthwise', strides=1)
226 layer_index = layer_index+1
227 conv22 = my_conv(model, conv21, layer_index, layer_type='pointwise', strides=1)
228 layer_index = layer_index+1
229 
230 conv23 = my_conv(model, conv22, layer_index, layer_type='depthwise', strides=2)
231 layer_index = layer_index+1
232 conv24 = my_conv(model, conv23, layer_index, layer_type='pointwise', strides=1)
233 layer_index = layer_index+1
234 
235 conv25 = my_conv(model, conv24, layer_index, layer_type='depthwise', strides=1)
236 layer_index = layer_index+1
237 conv26 = my_conv(model, conv25, layer_index, layer_type='pointwise', strides=1)
238 layer_index = layer_index+1
239 
240 pooling_res = tf.nn.avg_pool(conv26, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding="SAME")
241 pooling_res = tf.round(pooling_res)
242 
243 fc_res = my_fc(model, pooling_res, layer_index)
244 
245 with tf.Session() as sess:
246     layer_res = sess.run(fc_res)
247     print(layer_res.shape)
248     print('Custom result is', np.argmax(layer_res))

3.3 注意事項

實際的運行理論過程和3.1的分析基本一致。需要注意的是,在tensorflow的padding操作中,正常的padding是上下左右均添加,但是如果在stride為2時,padding並不是在輸入特征圖上添加一圈的數據,有可能是只有半圈,在tensorflow中的半圈padding中,針對於特征圖的右邊和下邊;而在pytorch中的半圈padding中,針對於特征圖的上邊和左邊。筆者一開始使用的pytorch函數進行前向推理的模擬,因為沒有注意到這個問題,導致最終分類的結果錯誤,最后才發現是tensorflow和pytorch中對與padding的操作方式不同所導致的。

4、總結

根據本文提供的Python腳本,可以很方便的完成權重數據的提取,同時對tflite,尤其是量化后數據的運行過程有一個原理性認識,這樣其實可以脫離原有的tensorflow框架,能夠將訓練后的神經網絡運行在任何的平台上。

后續筆者將繼續介紹tflite下mobilenet_v1網絡在STM32上的實現過程,雖然Tensorflow官方已經發布了相關的源碼,並包含有示例,但是官方采用的是C++的編程方法,不利於網絡的移植,筆者將分享自己的移植過程。當然啦,性能比不上ARM官方的CMSIS_NN庫,代碼里面還有待進一步的提升,筆者也希望能夠起到一個拋磚引玉的作用,希望能夠使得那些想在微處理器上運行神經網絡的開發者能夠多一種選擇。

最后附上源碼:點我下載


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM