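There are many tutorials online about deploying TensorFlow Lite on Android, but most of them only cover how to deploy an already-trained model, not how to train it. In practice, deployment requires knowing the preprocessing details used during training, which is why a model you trained yourself often runs into all kinds of problems once it is deployed on Android. This post therefore records the details of the whole pipeline, from training and exporting on the PC to deploying on Android. Discussion and feedback are welcome.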
PC OS: Ubuntu 14
TensorFlow version: TensorFlow 1.14
Android version: 9.0
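Training on the PC
Dataset: custom-generated
Training framework: TensorFlow Slim. I won't cover how to install TensorFlow Slim here; please look it up yourself.
Data generation code: it generates 50,000 three-channel 28×28 captcha images in 10 classes (the digits 0-9) and saves them under datasets/images/.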
```python
# -*- coding: utf-8 -*-
import os

import cv2
import numpy as np
from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Generate a digit image."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)  # PIL image in RGB order
    image = np.array(image, dtype=np.uint8)
    return image


if __name__ == '__main__':
    output_dir = './datasets/images/'
    os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists
    for i in range(50000):
        label = np.random.randint(0, 10)
        image = generate_captcha(str(label))
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        # cv2.imwrite expects BGR, so convert the RGB array before saving
        cv2.imwrite(output_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
```
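As a quick sanity check on the generated data, the sketch below parses the label back out of each file name and counts images per class. It assumes the `image{index}_{label}.jpg` naming used above; the helper names `label_from_filename` and `class_histogram` are my own, not part of any library.

```python
import os
from collections import Counter


def label_from_filename(path):
    """Parse the digit label from a name like 'image123_7.jpg'.

    Assumes the naming scheme image{index}_{label}.jpg used by the
    generator; the label is the text after the last underscore.
    """
    stem = os.path.splitext(os.path.basename(path))[0]  # e.g. 'image123_7'
    return int(stem.rsplit('_', 1)[-1])


def class_histogram(filenames):
    """Count how many images fall into each of the 10 classes."""
    counts = Counter(label_from_filename(f) for f in filenames)
    return [counts.get(c, 0) for c in range(10)]


names = ['image1_3.jpg', 'image2_3.jpg', './datasets/images/image3_0.jpg']
print(class_histogram(names))  # [1, 0, 0, 2, 0, 0, 0, 0, 0, 0]
```

In practice you would pass it `glob.glob('./datasets/images/*.jpg')`; since labels come from `np.random.randint`, the ten counts should be roughly equal.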
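Training: I built a seven-layer convolutional network with TensorFlow Slim. The final test accuracy was around 96% to 99%, and the model is 1.2 MB, which makes it suitable for mobile deployment. During training I did two things:
1. Named the model's input and output nodes. These names are needed when testing the model on the PC, and they also make it easy to see exactly what format the model's output has; the mobile code must match that format.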
```python
inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
# .......
# .......
prob_ = tf.identity(prob, name='prob')
```
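2. Saved the model directly in PB format at the end of training. List every output node you will fetch later: when the graph is frozen, nodes that the listed outputs do not depend on are dropped.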
```python
# Freeze the trained variables into constants and save the graph directly as a PB file
constant_graph = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ['inputs', 'classes', 'prob'])
with tf.gfile.FastGFile('model3.pb', mode='wb') as f:  # the model file is model3.pb
    f.write(constant_graph.SerializeToString())
```
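Why the output list matters: freezing keeps only the nodes that the requested outputs depend on (`convert_variables_to_constants` prunes the graph to that subgraph). The toy model below uses a plain dict, not TensorFlow's implementation, and the intermediate node names (`logits`, `softmax`, `argmax`) are hypothetical because `model.py` is not shown in this post.

```python
def reachable_nodes(graph, outputs):
    """Return every node needed to compute `outputs`.

    `graph` maps a node name to the list of node names it reads from.
    This is a toy model of the pruning done when freezing a graph:
    nodes that no requested output depends on are dropped.
    """
    keep, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        stack.extend(graph.get(name, []))
    return keep


# Hypothetical graph shaped like the one built in the training code.
toy_graph = {
    'inputs': [],
    'logits': ['inputs'],
    'softmax': ['logits'],
    'prob': ['softmax'],
    'argmax': ['softmax'],
    'classes': ['argmax'],
}

print(sorted(reachable_nodes(toy_graph, ['inputs', 'prob'])))
# ['inputs', 'logits', 'prob', 'softmax']: 'classes' is not kept
```

So a PB frozen with only `prob` as output cannot be queried for `classes:0` later.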
The full training code:
```python
# -*- coding: utf-8 -*-
"""Train a CNN model to classify 10 digits.

Example Usage:
---------------
python3 train.py \
    --images_path: Path to the training images (directory).
    --model_output_path: Path to model.ckpt.
"""
import cv2
import glob
import numpy as np
import os
import tensorflow as tf

import model  # the author's own model definition (model.py, not shown in this post)

from tensorflow.python.framework import graph_util

flags = tf.app.flags

flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
FLAGS = flags.FLAGS


def get_train_data(images_path):
    """Get the training images from images_path.

    Args:
        images_path: Path to training images.

    Returns:
        images: A 4-D array of RGB images.
        labels: A 1-D array of integers representing the classes of images.

    Raises:
        ValueError: If images_path does not exist.
    """
    if not os.path.exists(images_path):
        raise ValueError('images_path does not exist.')

    images = []
    labels = []
    images_path = os.path.join(images_path, '*.jpg')
    count = 0
    for image_file in glob.glob(images_path):
        count += 1
        if count % 100 == 0:
            print('Load {} images.'.format(count))
        image = cv2.imread(image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        # Assume the name of each image is imagexxx_label.jpg
        label = int(image_file.split('_')[-1].split('.')[0])
        images.append(image)
        labels.append(label)
    images = np.array(images)
    labels = np.array(labels)
    return images, labels


def next_batch_set(images, labels, batch_size=128):
    """Sample a batch of training data (with replacement).

    Args:
        images: A 4-D array representing the training images.
        labels: A 1-D array representing the classes of images.
        batch_size: An integer.

    Return:
        batch_images: A batch of images.
        batch_labels: A batch of labels.
    """
    indices = np.random.choice(len(images), batch_size)
    batch_images = images[indices]
    batch_labels = labels[indices]
    return batch_images, batch_labels


def main(_):
    inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
    labels = tf.placeholder(tf.int32, shape=[None], name='labels')

    cls_model = model.Model(is_training=True, num_classes=10)
    preprocessed_inputs = cls_model.preprocess(inputs)  # preprocessing is part of the graph
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    classes = postprocessed_dict['classes']
    prob = postprocessed_dict['prob']
    classes_ = tf.identity(classes, name='classes')
    prob_ = tf.identity(prob, name='prob')
    acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(0.05, global_step, 150, 0.9)

    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_step = optimizer.minimize(loss, global_step)

    saver = tf.train.Saver()

    images, targets = get_train_data(FLAGS.images_path)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        for i in range(6000):
            batch_images, batch_labels = next_batch_set(images, targets)
            train_dict = {inputs: batch_images, labels: batch_labels}

            sess.run(train_step, feed_dict=train_dict)

            loss_, acc_, prob__, classes__ = sess.run(
                [loss, acc, prob_, classes_], feed_dict=train_dict)

            train_text = 'step: {}, loss: {}, acc: {}, classes: {}'.format(
                i+1, loss_, acc_, classes__)
            print(train_text)
            print(prob__)

        saver.save(sess, FLAGS.model_output_path)
        # Freeze the variables and save the graph directly as a PB file
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['inputs', 'classes', 'prob'])
        with tf.gfile.FastGFile('model3.pb', mode='wb') as f:  # the model file is model3.pb
            f.write(constant_graph.SerializeToString())


if __name__ == '__main__':
    tf.app.run()
```
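The learning rate in the training code comes from `tf.train.exponential_decay`, whose documented formula is `learning_rate * decay_rate ** (global_step / decay_steps)`, with integer division when `staircase=True`. The pure-Python re-implementation below is only for inspecting the schedule on paper; it is not how TensorFlow computes it internally.

```python
def exponential_decay(initial_lr, step, decay_steps, decay_rate, staircase=False):
    """Mirror the formula documented for tf.train.exponential_decay."""
    if staircase:
        exponent = step // decay_steps  # decay in discrete jumps
    else:
        exponent = step / decay_steps   # smooth decay (the default)
    return initial_lr * decay_rate ** exponent


# Settings used in the training code above
for step in (0, 150, 1500, 6000):
    print(step, exponential_decay(0.05, step, 150, 0.9))
```

With these settings the rate is 0.045 after 150 steps and roughly 7.4e-4 by step 6000, so most of the learning happens in the first couple of thousand steps.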
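Pay particular attention here to whether the images were preprocessed during training, for example by subtracting a mean and dividing by a normalization constant, because the mobile side must apply exactly the same operations. In my training, preprocessing consisted of mean subtraction and division, and I packaged these two ops into the model itself: image data entering the model is preprocessed first and only then passes through the convolutions and the rest of the network. So the mobile side needs no separate preprocessing code. Often, however, the exported model does not include the preprocessing ops, in which case the mobile code needs a few extra lines to subtract the mean and normalize before feeding the data to the classifier.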
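If your exported model does not contain the preprocessing ops, you have to reproduce them outside the graph. A minimal NumPy sketch, assuming a mean of 128 and a scale of 128 (the constants in the commented-out Android lines later in this post); substitute whatever constants your own training used. The function names are illustrative only.

```python
import numpy as np


def preprocess(image, mean=128.0, scale=128.0):
    """Subtract a mean and divide by a scale, the same for every channel.

    mean=128 and scale=128 are assumptions matching (x - 128) / 128;
    they must equal whatever the training pipeline used.
    """
    x = np.asarray(image, dtype=np.float32)
    return (x - mean) / scale


def to_batch(image):
    """Add the leading batch dimension: [28, 28, 3] -> [1, 28, 28, 3]."""
    return np.expand_dims(image, axis=0)


pixels = np.array([0, 128, 255], dtype=np.uint8)
print(preprocess(pixels))  # [-1.  0.  0.9921875]
```

Applying this on top of a model that already normalizes internally would normalize twice, which is exactly the kind of silent mismatch described above.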
Another way to export a ckpt model to a PB model:
```python
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: path prefix of the ckpt model
    :param output_graph: path where the PB model is saved
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check that the ckpt files in the directory are usable
    # input_checkpoint = checkpoint.model_checkpoint_path  # get the ckpt file path

    # Output node names; they must exist in the original model.
    # List every node fetched later (the PB test script fetches both classes and prob).
    # input_node_names = "inputs"
    output_node_names = "inputs,classes,prob"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph's variable values
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: turn variables into constants
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of nodes in the frozen graph

        # for op in graph.get_operations():
        #     print(op.name, op.values())


# Path of the input ckpt model
input_checkpoint = 'model/model.ckpt'
# Path of the output PB model
out_pb_path = "frozen_model.pb"
# Convert ckpt to PB
freeze_graph(input_checkpoint, out_pb_path)
```
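Before running `freeze_graph`, it can help to confirm that the checkpoint prefix actually has its files. `tf.train.Saver` (V2 format, the default) writes `<prefix>.index` and one or more `<prefix>.data-*` shards, plus `<prefix>.meta` by default; `import_meta_graph` needs the `.meta` file and `saver.restore` needs the other two. The helper `checkpoint_files_present` is a hypothetical convenience, not a TensorFlow API.

```python
import glob
import os


def checkpoint_files_present(prefix):
    """Return True if `prefix` has .meta, .index and at least one .data-* shard."""
    has_meta = os.path.exists(prefix + '.meta')    # needed by import_meta_graph
    has_index = os.path.exists(prefix + '.index')  # needed by saver.restore
    has_data = len(glob.glob(prefix + '.data-*')) > 0
    return has_meta and has_index and has_data
```

For example, `checkpoint_files_present('model/model.ckpt')` should be `True` after the training script has called `saver.save`.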
Code to convert the PB model to TFLite format:
```python
import tensorflow as tf

# Replace with the path to your own PB file
path = "model3.pb"
# If you don't know your model's input/output node names, visualize the graph
# in TensorBoard; the node names are shown there
inputs = ["inputs"]
outputs = ["prob"]
# Convert the PB model to a TFLite model
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, outputs)
# converter.post_training_quantize = True  # optional post-training weight quantization
tflite_model = converter.convert()
with open("model3.tflite", "wb") as f:
    f.write(tflite_model)
```
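A cheap check that the converter produced a plausible file: TFLite models are FlatBuffers whose 4-byte file identifier `TFL3` sits at byte offset 4 (bytes 0-3 hold the root table offset). The sketch only checks that identifier, and the helper name is my own; it does not validate the rest of the model.

```python
def looks_like_tflite(data):
    """Return True if `data` carries the TFLite FlatBuffer file identifier."""
    # bytes 0-3: root table offset; bytes 4-7: file identifier
    return len(data) >= 8 and data[4:8] == b'TFL3'
```

Usage: `looks_like_tflite(open('model3.tflite', 'rb').read())`. A `False` here usually means the wrong file was written or the download/copy was truncated.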
There is also a way to convert the model to TFLite with bazel.
Go to the TensorFlow source directory, build the tools in two steps, then run toco:
```shell
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco

./bazel-bin/tensorflow/lite/toco/toco \
  --input_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.pb \
  --output_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.tflite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --inference_type=FLOAT \
  --input_shapes=1,28,28,3 \
  --input_arrays=inputs \
  --output_arrays=prob
```
Testing the accuracy of the PB model
```python
# -*- coding: utf-8 -*-
"""Evaluate the trained CNN model.

Example Usage:
---------------
python3 inference_pb.py \
    --frozen_graph_path: Path to model frozen graph.
"""
import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            # The frozen graph must contain both 'classes' and 'prob'
            inputs = model_graph.get_tensor_by_name('inputs:0')
            classes = model_graph.get_tensor_by_name('classes:0')
            prob = model_graph.get_tensor_by_name('prob:0')
            for i in range(10):
                label = np.random.randint(0, 10)
                image = generate_captcha(str(label))
                # Add the batch dimension: [28, 28, 3] -> [1, 28, 28, 3]
                image_np = np.expand_dims(image, axis=0).astype(np.float32)
                predicted_label, probs = sess.run([classes, prob],
                                                  feed_dict={inputs: image_np})
                print(predicted_label, ' vs ', label)
                print(probs)


if __name__ == '__main__':
    tf.app.run()
```
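The script above prints ten prediction/label pairs. To get an accuracy figure instead, you could collect the pairs and tally them in a confusion matrix, which also shows which digits get mixed up. This is a sketch with hypothetical helper names; it assumes integer predictions such as `int(predicted_label[0])` taken from the loop above.

```python
import numpy as np


def confusion_matrix(predicted, expected, num_classes=10):
    """Rows are true labels, columns are predicted labels."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    for p, e in zip(predicted, expected):
        matrix[int(e), int(p)] += 1
    return matrix


def accuracy(matrix):
    """Correct predictions sit on the diagonal."""
    total = matrix.sum()
    return float(np.trace(matrix) / total) if total else 0.0
```

Running the loop for, say, 1,000 samples instead of 10 and passing the collected lists to these helpers gives a more stable estimate than eyeballing ten lines.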
Testing the accuracy of the TFLite model
```python
# -*- coding:utf-8 -*-
import os
import time

import cv2
import numpy as np
import tensorflow as tf

test_image_dir = './test_images/'
# model_path = "./model/quantize_frozen_graph.tflite"
model_path = "./model3.tflite"

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

file_list = os.listdir(test_image_dir)

model_interpreter_time = 0
start_time = time.time()
# Iterate over the test images
for file in file_list:
    print('=========================')
    full_path = os.path.join(test_image_dir, file)
    print('full_path:{}'.format(full_path))

    img = cv2.imread(full_path)
    # OpenCV loads BGR, but the model was trained on RGB images
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    new_img = np.array(res_img, dtype=np.uint8)
    # Add a batch dimension: [28, 28, 3] -> [1, 28, 28, 3]
    image_np_expanded = np.expand_dims(new_img, axis=0)
    image_np_expanded = image_np_expanded.astype('float32')  # the input tensor is float32

    # Feed the data and run the model
    model_interpreter_start_time = time.time()
    interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    model_interpreter_time += time.time() - model_interpreter_start_time

    # Drop the unused batch dimension
    result = np.squeeze(output_data)
    print('result:{}'.format(result))
    # The output is a length-10 vector (digits 0-9); the index of the maximum is the predicted digit
    print('predicted:{}'.format(int(np.argmax(result))))

used_time = time.time() - start_time
print('used_time:{}'.format(used_time))
print('model_interpreter_time:{}'.format(model_interpreter_time))
```
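The loop above prints the raw 10-element probability vector. A small helper that turns it into the k most likely labels makes the output easier to read; it is a hypothetical convenience (not part of the TFLite API) and by default uses the digits `'0'` to `'9'` as labels.

```python
import numpy as np


def top_k(probs, k=3, labels=None):
    """Return the k most likely (label, probability) pairs, best first."""
    probs = np.squeeze(np.asarray(probs, dtype=np.float32))  # [1, 10] -> [10]
    if labels is None:
        labels = [str(i) for i in range(probs.shape[0])]
    order = np.argsort(probs)[::-1][:k]  # indices by descending probability
    return [(labels[i], float(probs[i])) for i in order]
```

Inside the loop, `print(top_k(output_data))` would replace the raw vector with something like a ranked list of digits and their probabilities.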
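Once the model is trained, the next step is deploying it on Android. This step is actually simple: you only need to replace the relevant parts of the Android code. I will upload the Android code to CSDN for anyone who needs it. So which parts of the code need to change?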
```java
// Model input size
private int[] ddims = {1, 3, 28, 28};

// Model names
private static final String[] PADDLE_MODEL = {
        "model3",
        "mobilenet_quant_v1_224",
        "mobilenet_v1_1.0_224",
        "mobilenet_v2"
};

// Label file name
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel1.txt")));

// Model output type; it is easy to check on the PC (e.g. in the TFLite test script)
float[][] labelProbArray = new float[1][10];

// Input preprocessing: depends on whether it is already built into the model.
// Mine is, so the normalizing lines stay commented out.
// imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
// imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
// imgData.putFloat((((val & 0xFF) - 128f) / 128f));
imgData.putFloat(((val >> 16) & 0xFF));
imgData.putFloat(((val >> 8) & 0xFF));
imgData.putFloat((val & 0xFF));
```
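The bit shifts in the Java snippet unpack one packed ARGB pixel int into R, G and B, written in that order (channels last, matching the `[1, 28, 28, 3]` input). The Python mirror below is only for checking values on the PC; the function name is my own, and I assume the pixel ints come from Android's `Bitmap.getPixels`, since the full loop is not shown. Python's `>>` on negative ints is an arithmetic shift like Java's, so the signed values Java produces unpack the same way.

```python
def argb_to_rgb_floats(val, normalize=False):
    """Unpack a packed ARGB pixel int into (R, G, B) floats.

    Mirrors the Java shifts: R = (val >> 16) & 0xFF, G = (val >> 8) & 0xFF,
    B = val & 0xFF. normalize=True applies (x - 128) / 128 like the
    commented-out lines; keep it False when the model normalizes itself.
    """
    r = (val >> 16) & 0xFF
    g = (val >> 8) & 0xFF
    b = val & 0xFF
    channels = (float(r), float(g), float(b))
    if normalize:
        channels = tuple((c - 128.0) / 128.0 for c in channels)
    return channels


print(argb_to_rgb_floats(0xFF102030))  # (16.0, 32.0, 48.0)
```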
I've left a test image you can use; the correct result should be 0.0. The Android code is here; to download it from CSDN, search for anquangan.
Code for listing the nodes of a PB model
```python
# coding:utf-8
import tensorflow as tf

tf.reset_default_graph()  # reset the computation graph
output_graph_path = 'model3.pb'
with tf.Session() as sess:
    output_graph_def = tf.GraphDef()
    # get the default graph
    graph = tf.get_default_graph()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    # number of nodes in the graph
    print("%d ops in the final graph." % len(output_graph_def.node))

    node_names = [node.name for node in output_graph_def.node]
    print(node_names)
    print('---------------------------')
    # write a log file to log_graph/ so the model can be visualized in TensorBoard
    # summaryWriter = tf.summary.FileWriter('log_graph/', graph)

    for op in graph.get_operations():
        # print each op's name and its output tensors
        print(op.name, op.values())
```
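Printing every node can produce a long list. As a heuristic (my assumption, not a TensorFlow API), inputs are usually the `Placeholder` nodes and outputs are nodes that no other node consumes. The sketch works on `(name, op, inputs)` tuples, which you could build from the graph above with `[(n.name, n.op, list(n.input)) for n in output_graph_def.node]`; input references may carry a `:N` output suffix or a `^` control-dependency prefix, which are stripped.

```python
def guess_io_nodes(nodes):
    """Guess input and output node names from (name, op, inputs) tuples."""
    consumed = set()
    for _, _, refs in nodes:
        for ref in refs:
            consumed.add(ref.lstrip('^').split(':')[0])  # 'logits:0' -> 'logits'
    placeholders = [name for name, op, _ in nodes if op == 'Placeholder']
    sinks = [name for name, _, _ in nodes if name not in consumed]
    return placeholders, sinks


toy_nodes = [
    ('inputs', 'Placeholder', []),
    ('sub/y', 'Const', []),
    ('sub', 'Sub', ['inputs', 'sub/y']),
    ('logits', 'MatMul', ['sub']),
    ('prob', 'Identity', ['logits']),
    ('classes', 'Identity', ['logits:0', '^sub']),
]
print(guess_io_nodes(toy_nodes))  # (['inputs'], ['prob', 'classes'])
```

Treat the result as a shortlist to confirm in TensorBoard, not as ground truth.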