A record of the process of building a digital meter detection project, with some of the code included.
Business background: there are four meters, each displaying a 4-digit (red) number, so four 4-digit numbers have to be detected and recognized. The scene also contains three indicator lights, and each light can show one of three colors: red, yellow, or green.
My task is to detect the exact values of the four 4-digit numbers in real time, and to classify the three lights.
Approach: first locate the regions containing the digits and the lights in the camera image, then run recognition or classification on those regions.
Solutions:
For digit recognition, there are the following options:
1. Template matching. Since the digits have a fixed, regular format, template matching guarantees accuracy, but it is relatively slow.
2. Character segmentation after locating the digit region. The advantage is that once the characters are segmented, a very simple neural network suffices to classify the digits 0-9; the drawback is that the digits have to be recognized one at a time.
3. Direct classification with a CNN, similar to captcha recognition. To feed the text label of a captcha image into a convolutional network for training, the text has to be vectorized (see https://my.oschina.net/u/876354/blog/3048523, and the sketch after this list). The advantage is that the CNN itself is very simple; the drawback is that a large amount of training data is needed, otherwise the model predicts poorly.
4. The fairly mature CNN+LSTM+CTC combination used in deep-learning OCR. This fourth option is the one I adopted: the model is robust, and thanks to CTC, meters with variable-length readings can still be recognized if they are swapped in later. See https://my.oschina.net/u/876354/blog/3070699 for the underlying theory. The network I actually use is LSTM+CTC. The CNN was dropped because the color-based extraction of the 4-digit targets (described under "colored light classification" below) pins down the digit regions directly: the training data consists of exactly the digit images the camera collects with that method, and during prediction each captured frame is likewise automatically split into four small digit images that are fed to the model. Since color-based extraction is quite stable, no CNN feature extractor was needed.
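As a concrete illustration of the label vectorization mentioned in option 3, a fixed-length 4-digit label can be one-hot encoded like this (a minimal sketch; the function name and shapes are my own, see the linked blog for the original scheme):

import numpy as np

def text2vec(text, charset='0123456789'):
    # One-hot encode a fixed-length digit string: '1233' -> a 4 x 10 matrix
    vec = np.zeros((len(text), len(charset)))
    for i, ch in enumerate(text):
        vec[i, charset.index(ch)] = 1.0
    return vec.flatten()  # a 40-dim training target for a 4-digit label

# text2vec('1233').shape -> (40,)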
For classifying the colored lights, there are the following options:
1. Deep-learning object detection, either one-stage or two-stage; the classics are Fast R-CNN, SSD, YOLO, and so on. But since this is only a three-way light classification in a very simple scene, I did not use this approach.
2. Color-based target extraction with OpenCV: in the specific scene where the camera captures the meters and the lights, sample the HSV color-space values of the three light colors, then use a series of OpenCV APIs to extract and classify the targets (inRange to get a binary image, then blur, erode, dilate, compute the gradient, and find the target contours from the gradient). I used this second method.
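To obtain those HSV ranges in practice, one option is a small click-to-probe tool like the sketch below (illustrative only; 'scene.jpg' stands for a snapshot of the actual meter-and-light scene):

import cv2

def pick_hsv(event, x, y, flags, param):
    # Print the HSV value of the clicked pixel, to help choose inRange bounds
    if event == cv2.EVENT_LBUTTONDOWN:
        print('HSV at (%d, %d):' % (x, y), param[y, x])

img = cv2.imread('scene.jpg')  # hypothetical snapshot of the scene
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
cv2.namedWindow('pick')
cv2.setMouseCallback('pick', pick_hsv, hsv)
cv2.imshow('pick', img)
cv2.waitKey(0)
cv2.destroyAllWindows()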
Code:
1. Extract the bounding rectangles of the digits based on color

Note: read the image, convert it from RGB to the HSV color space, and use the OpenCV API below to extract the target color based on the HSV reference values shown above; anything from black to purple can be extracted. The core API is cv2.inRange().
def extrace_object_demo(src):
    img = cv2.imread(src)
    # 1. Convert RGB to the HSV color space
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # 3 channels
    img_binary = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # 1 channel
    # 2. Define the color range to extract (filter): three values, one per channel
    # Red
    lower_hsv_g = np.array([156, 43, 46])
    upper_hsv_g = np.array([180, 255, 255])
    # 3. Filter and extract, producing a binary image
    mask_red = cv2.inRange(hsv, lower_hsv_g, upper_hsv_g)  # 1 channel
    # Stack the two images side by side for display
    res = np.hstack((img_binary, mask_red))
    cv2.imshow("res", res)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return mask_red
The original image:

The binary image obtained for red (a screenshot; please ignore the change in image size):

2. Compute the gradient of the color target from the binary image
Note: the Gaussian blur kernel chosen here is 3x3; other sizes (odd ones are recommended) work too, and the number of blur passes can be changed. Likewise, the kernels and iteration counts for erosion and dilation are adjustable.
def img_preprocessing(src):
    # Run the color-based filtering from the previous step
    mask_red = extrace_object_demo(src)
    # Gaussian blur to remove noise
    gaussian = cv2.GaussianBlur(mask_red, (3, 3), 1)
    # Erode
    kernel = np.ones((5, 5), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    # Dilate
    dige_dilate = cv2.dilate(erosion, kernel, iterations=1)
    # Morphological gradient = dilation - erosion
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)
    return gradient
The extracted gradient looks like this; it is a little uneven, but that does not affect the final result:

3. Use the OpenCV API to find the rectangular contours from the gradient
Note: num=4 here can be changed; if the goal is later to extract 3 rectangular targets, change it to 3, and so on.
Also note that the img passed to cv2.rectangle() should be the original image, so that the rectangles are drawn onto it. The other APIs are standard OpenCV; just learn how they work and pass the right arguments.
def get_box(gt, num=4):
    # Find the contours of the gradient computed in the previous step
    contours, _ = cv2.findContours(gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # A list to hold the bounding-box coordinates
    list_box = []
    # Keep the num largest contours by area
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:num]
    for c in cnt:
        # Bounding-box coordinates
        x, y, w, h = cv2.boundingRect(c)
        # Store them
        list_box.append((x, y, w, h))
        # Draw the rectangle on the original image (img and the display helper
        # cv_show are assumed to be defined elsewhere in the module)
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
        cv_show("draw_img", draw_img)
    # The list may contain duplicates, so deduplicate before returning
    return list(set(list_box))
The returned coordinate list typically looks like this: [(83, 252, 154, 67), (66, 50, 173, 76), (374, 51, 183, 78), (366, 265, 174, 71)]
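Note that because the list is deduplicated through set(), its order is not guaranteed (the main loop at the end of this post swaps two entries for exactly this reason). If a stable reading order is wanted, a sort along these lines works (an illustrative helper, not part of the project code; row_tol is an assumed pixel tolerance for grouping boxes into rows):

def sort_boxes(list_box, row_tol=40):
    # Reading order: top-to-bottom rows, left-to-right within each row
    return sorted(list_box, key=lambda b: (b[1] // row_tol, b[0]))

# sort_boxes([(83, 252, 154, 67), (66, 50, 173, 76), (374, 51, 183, 78), (366, 265, 174, 71)])
# -> [(66, 50, ...), (374, 51, ...), (83, 252, ...), (366, 265, ...)]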
The resulting image:

4. Using the coordinates returned above, the four small digit-meter images can be cropped from the original image in turn, which is what the subsequent training and other steps need.
def create_crapimg(raw_img, list_box):
    img = cv2.imread(raw_img)
    # A list to hold the cropped images
    list_crap = []
    # Crop the target region of each bounding box out of the image
    for i, box in enumerate(list_box):
        x, y, w, h = box
        # Crop the region
        img_crap = img[y:y+h, x:x+w]
        # Resize to the model's input shape
        img_crap = cv2.resize(img_crap, (256, 32), interpolation=cv2.INTER_AREA)
        # Keep the crop
        list_crap.append(img_crap)
    return list_crap
The resulting images look like this (the reason for the file naming is explained below):

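These crops double as training data. A sketch of the collection step, assuming the true readings are verified by hand (the frame path and labels here are hypothetical; the naming convention is explained in the next section):

list_crap = create_crapimg('frame_0001.jpg', list_box)  # hypothetical frame
labels = ['1233', '0456', '7890', '1111']               # human-verified readings
for i, (img_crap, label) in enumerate(zip(list_crap, labels)):
    # File name format "label_anything.jpg", e.g. 1233_0000.jpg
    cv2.imwrite('./tmp/train_data/%s_%04d.jpg' % (label, i), img_crap)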
Summary: that completes the color-based extraction for the digital meters; what remains is how to train and use the model.
Model and training code:
Note: simply run train() in the code below, making sure every path is correct, and name each image file as label_anything.jpg: for example, if the number is 1233, name it 1233_whatever.jpg. This matters because the function below reads the file names and converts the labels into sparse matrices, which saves storage space and is convenient for training.
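As a concrete illustration of that conversion (sparse_tuple_from is the function defined in the training script below):

# Labels taken from file names such as 1233_a.jpg and 0456_b.jpg
labels = [list('1233'), list('0456')]
indices, values, shape = sparse_tuple_from(labels)
# indices -> [[0 0] [0 1] [0 2] [0 3] [1 0] [1 1] [1 2] [1 3]]
# values  -> [1 2 3 3 0 4 5 6]
# shape   -> [2 4]  (2 sequences, maximum length 4)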
The model code itself is TensorFlow's built-in wrapper, namely:
# Define the LSTM network
cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)
# Stack num_layers LSTM cells
stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
# Run the stacked cell over the input sequence
outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)
The LSTMCell and MultiRNNCell here can be swapped for other cells wrapped in the rnn module to see whether results improve, and num_hidden and num_layers can also be changed (multiples or powers of 2 are typical); the larger they are, the more complex the model.
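For example, swapping in GRU cells only touches these lines (a sketch, not tested on this data; building one cell instance per layer also avoids the parameter-sharing pitfall of [cell] * num_layers):

# Hypothetical variant of the RNN definition in get_train_model
cells = [tf.contrib.rnn.GRUCell(num_hidden) for _ in range(num_layers)]
stack = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)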
#coding:utf-8
# Recognize variable-length digit strings with LSTM + CTC
import numpy as np
import cv2
import os
import tensorflow as tf
import random
import time
import datetime
# from captcha.image import ImageCaptcha
from PIL import Image, ImageFont, ImageDraw

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Constants
# Character set
DIGITS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Image size
OUTPUT_SHAPE = (32, 256)
# Maximum number of training epochs
num_epochs = 10000
num_hidden = 128
num_layers = 2
num_classes = len(DIGITS) + 1
# Initial learning rate
INITIAL_LEARNING_RATE = 1e-3
DECAY_STEPS = 5000
REPORT_STEPS = 100
LEARNING_RATE_DECAY_FACTOR = 0.9
MOMENTUM = 0.9
BATCHES = 10
BATCH_SIZE = 64
TRAIN_SIZE = BATCHES * BATCH_SIZE

# # Command-line flags
# # Number of training steps
# tf.app.flags.DEFINE_integer("max_step", 0, "number of training steps")
# # Model save path + model name
# tf.app.flags.DEFINE_string("model_dir", " ", "save path + model name")
# # If used, replace the values below with FLAGS.max_step and FLAGS.model_dir
# FLAGS = tf.app.flags.FLAGS
# # Command line (the model name is mandatory):
# # python xx.py --max_step=xx --load="xx + model name"

data_dir = './tmp/train_data/'
model_dir = './tmp/train_data_model/'

# Sparse matrix to sequence
def decode_a_seq(indexes, spars_tensor):
    decoded = []
    for m in indexes:
        str = DIGITS[spars_tensor[1][m]]
        decoded.append(str)
    return decoded

def decode_sparse_tensor(sparse_tensor):
    decoded_indexes = list()
    current_i = 0
    current_seq = []
    for offset, i_and_index in enumerate(sparse_tensor[0]):
        i = i_and_index[0]
        if i != current_i:
            decoded_indexes.append(current_seq)
            current_i = i
            current_seq = list()
        current_seq.append(offset)
    decoded_indexes.append(current_seq)
    result = []
    for index in decoded_indexes:
        result.append(decode_a_seq(index, sparse_tensor))
    return result

# Accuracy evaluation
# Input: predicted sequences decoded_list, target sequences test_targets
# Returns: the accuracy
def report_accuracy(decoded_list, test_targets):
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    # Number of correct predictions
    true_numer = 0
    # If the dimensions differ, some predictions failed; return directly
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
              " test and detect length doesn't match")
        return
    # Compare the predicted sequences with the targets and compute accuracy
    print("T/F: original(length) <-------> detected(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer = true_numer + 1
    accuracy = true_numer * 1.0 / len(original_list)
    print("Test Accuracy:", accuracy)
    return accuracy

# Convert a list of sequences into a sparse matrix
def sparse_tuple_from(sequences, dtype=np.int32):
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape

# Read the file names and labels into arrays
def get_file_text_array():
    file_name_array = []
    text_array = []
    for parent, dirnames, filenames in os.walk(data_dir):
        file_name_array = filenames
    for f in file_name_array:
        text = f.split('_')[0]
        text_array.append(text)
    return file_name_array, text_array

# Generate one training batch
def get_next_batch(file_name_array, text_array, batch_size=128):
    inputs = np.zeros([batch_size, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
    codes = []
    # Sample the training examples
    for i in range(batch_size):
        index = random.randint(0, len(file_name_array) - 1)
        image = cv2.imread(data_dir + file_name_array[index])
        image = cv2.resize(image, (OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]), interpolation=cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        text = text_array[index]
        inputs[i, :] = np.transpose(image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        codes.append(list(text))
    targets = [np.asarray(i) for i in codes]
    sparse_targets = sparse_tuple_from(targets)
    seq_len = np.ones(inputs.shape[0]) * OUTPUT_SHAPE[1]
    return inputs, sparse_targets, seq_len

def get_train_model():
    inputs = tf.placeholder(tf.float32, [None, None, OUTPUT_SHAPE[0]])
    targets = tf.sparse_placeholder(tf.int32)
    seq_len = tf.placeholder(tf.int32, [None])
    # Define the LSTM network
    cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True)
    # Stack num_layers LSTM cells
    stack = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)
    shape = tf.shape(inputs)
    batch_s, max_timesteps = shape[0], shape[1]
    outputs = tf.reshape(outputs, [-1, num_hidden])
    W = tf.Variable(tf.truncated_normal([num_hidden, num_classes], stddev=0.1), name="W")
    b = tf.Variable(tf.constant(0., shape=[num_classes]), name="b")
    logits = tf.matmul(outputs, W) + b
    logits = tf.reshape(logits, [batch_s, -1, num_classes])
    # Transpose to time-major order for CTC
    logits = tf.transpose(logits, (1, 0, 2))
    return logits, inputs, targets, seq_len, W, b

def train():
    # with tf.variable_scope("train"):
    # Read the training sample file names and labels
    file_name_array, text_array = get_file_text_array()
    # Learning-rate schedule
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, DECAY_STEPS,
                                               LEARNING_RATE_DECAY_FACTOR, staircase=True)
    # Build the network
    logits, inputs, targets, seq_len, W, b = get_train_model()
    # CTC loss
    loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len)
    cost = tf.reduce_mean(loss)
    # Optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))
    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session() as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        # To resume training from a checkpoint:
        # saver.restore(session, tf.train.latest_checkpoint(model_dir))
        for curr_epoch in range(num_epochs):
            train_cost = 0
            train_ler = 0
            for batch in range(BATCHES):
            # for batch in range(FLAGS.max_step):
                # Train one batch
                train_inputs, train_targets, train_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
                feed = {inputs: train_inputs, targets: train_targets, seq_len: train_seq_len}
                b_loss, b_targets, b_logits, b_seq_len, b_cost, steps, _ = session.run(
                    [loss, targets, logits, seq_len, cost, global_step, optimizer], feed)
                # Evaluate the model periodically
                if steps > 0 and steps % REPORT_STEPS == 0:
                    test_inputs, test_targets, test_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
                    test_feed = {inputs: test_inputs, targets: test_targets, seq_len: test_seq_len}
                    dd, log_probs, accuracy = session.run([decoded[0], log_prob, acc], test_feed)
                    report_accuracy(dd, test_targets)
                    # Save the recognition model
                    save_path = saver.save(session, model_dir + "lstm_ctc_model.ctpk", global_step=steps)
                    # save_path = saver.save(session, FLAGS.model_dir, global_step=steps)
                c = b_cost
                train_cost += c * BATCH_SIZE
            train_cost /= TRAIN_SIZE
            # Compute the validation loss
            train_inputs, train_targets, train_seq_len = get_next_batch(file_name_array, text_array, BATCH_SIZE)
            val_feed = {inputs: train_inputs, targets: train_targets, seq_len: train_seq_len}
            val_cost, val_ler, lr, steps = session.run([cost, acc, learning_rate, global_step], feed_dict=val_feed)
            log = "Epoch {}/{}, steps = {}, train_cost = {:.3f}, val_cost = {:.3f}"
            print(log.format(curr_epoch + 1, num_epochs, steps, train_cost, val_cost))
The prediction code:
Model loading is wrapped in a separate function to improve speed: the network structure is built only once. Prediction runs frame by frame on camera images, so reloading the model for every frame would waste a lot of time; for the same reason, the TensorFlow session in the main prediction loop is created only once, which saves a great deal of time (a session is resource-heavy).
# Wrapper around the LSTM+CTC text recognition capability
# Load the model
def load_model():
    # Build the network structure
    tf.reset_default_graph()
    logits, inputs, targets, seq_len, W, b = get_train_model()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
    saver = tf.train.Saver()
    # sess = tf.Session()
    # # Restore the checkpoint
    # saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    return saver, inputs, seq_len, decoded, log_prob

# Input: a dict of images
# Output: the recognized text for each image
def predict(images_path, saver, inputs, seq_len, decoded, log_prob, sess):
    # The checkpoint is restored once by the caller:
    # saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    result_dict = {}
    for i, image in images_path.items():
        # Image preprocessing
        image = cv2.resize(image, (OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]), interpolation=cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        pred_inputs = np.zeros([1, OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]])
        pred_inputs[0, :] = np.transpose(image.reshape((OUTPUT_SHAPE[0], OUTPUT_SHAPE[1])))
        pred_seq_len = np.ones(1) * OUTPUT_SHAPE[1]
        # Run the model
        pred_feed = {inputs: pred_inputs, seq_len: pred_seq_len}
        dd, log_probs = sess.run([decoded[0], log_prob], pred_feed)
        # Convert the sparse result back into text
        detected_list = decode_sparse_tensor(dd)[0]
        detected_text = ''
        for d in detected_list:
            detected_text = detected_text + d
        result_dict[i + 1] = detected_text
    # Return the result dictionary
    return result_dict
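A minimal sketch of the load-once pattern in use (assuming a trained checkpoint exists in model_dir; crops stands for the per-frame image dict produced by create_crapimg), which is how the main loop below calls it:

saver, inputs, seq_len, decoded, log_prob = load_model()
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint(model_dir))  # restore once
# Then, for every frame: crops = {0: img0, 1: img1, ...}
result = predict(crops, saver, inputs, seq_len, decoded, log_prob, sess)
print(result)  # e.g. {1: '1233', 2: '0456', ...}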
UI display based on PyQt5: a camera continuously captures images, and each captured frame together with its prediction results is handed to a UI built with PyQt5. The code follows:
demo.py (the PyQt5 UI code)
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'demo.ui'
#
# Created by: PyQt5 UI code generator 5.9.2
#
# WARNING! All changes made in this file will be lost!

from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_mainWindow(object):
    def setupUi(self, mainWindow):
        mainWindow.setObjectName("mainWindow")
        mainWindow.resize(1920, 1080)
        palette1 = QtGui.QPalette()
        palette1.setBrush(mainWindow.backgroundRole(), QtGui.QBrush(QtGui.QPixmap('images/bg2.jpg')))
        mainWindow.setPalette(palette1)
        self.centralwidget = QtWidgets.QWidget(mainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.graphicsView = PlotWidget(self.centralwidget)
        self.graphicsView.setGeometry(QtCore.QRect(160, 340, 250, 150))
        brush = QtGui.QBrush(QtGui.QColor(255, 255, 255, 0))
        brush.setStyle(QtCore.Qt.SolidPattern)
        self.graphicsView.setBackgroundBrush(brush)
        self.graphicsView.setObjectName("graphicsView")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(170, 90, 81, 41))
        self.pushButton.setObjectName("pushButton")
        # palette = QtGui.QPalette()
        # brush = QtGui.QBrush(QtGui.QColor(255, 255, 0))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.ButtonText, brush)
        # brush = QtGui.QBrush(QtGui.QColor(255, 255, 0))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.ButtonText, brush)
        # brush = QtGui.QBrush(QtGui.QColor(120, 120, 120))
        # brush.setStyle(QtCore.Qt.SolidPattern)
        # palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.ButtonText, brush)
        # self.pushButton.setPalette(palette)
        self.graphicsView_2 = PlotWidget(self.centralwidget)
        self.graphicsView_2.setGeometry(QtCore.QRect(1460, 340, 250, 150))
        self.graphicsView_2.setObjectName("graphicsView_2")
        self.graphicsView_3 = PlotWidget(self.centralwidget)
        self.graphicsView_3.setGeometry(QtCore.QRect(160, 660, 250, 150))
        self.graphicsView_3.setObjectName("graphicsView_3")
        self.graphicsView_4 = PlotWidget(self.centralwidget)
        self.graphicsView_4.setGeometry(QtCore.QRect(1460, 660, 251, 151))
        self.graphicsView_4.setObjectName("graphicsView_4")
        self.imageLabel = QtWidgets.QLabel(self.centralwidget)
        self.imageLabel.setGeometry(QtCore.QRect(615, 340, 640, 480))
        self.imageLabel.setAutoFillBackground(False)
        self.imageLabel.setFrameShape(QtWidgets.QFrame.Box)
        self.imageLabel.setText("")
        self.imageLabel.setObjectName("imageLabel")
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(660, 130, 891, 111))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.label.setFont(font)
        self.label.setObjectName("label")
        self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget.setGeometry(QtCore.QRect(170, 180, 231, 131))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setObjectName("gridLayout")
        self.label_7 = QtWidgets.QLabel(self.layoutWidget)
        self.label_7.setText("")
        self.label_7.setObjectName("label_7")
        self.gridLayout.addWidget(self.label_7, 2, 1, 1, 1)
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 0, 1, 1)
        self.label_5 = QtWidgets.QLabel(self.layoutWidget)
        self.label_5.setText("")
        self.label_5.setObjectName("label_5")
        self.gridLayout.addWidget(self.label_5, 0, 1, 1, 1)
        self.label_4 = QtWidgets.QLabel(self.layoutWidget)
        self.label_4.setObjectName("label_4")
        self.gridLayout.addWidget(self.label_4, 2, 0, 1, 1)
        self.label_6 = QtWidgets.QLabel(self.layoutWidget)
        self.label_6.setText("")
        self.label_6.setObjectName("label_6")
        self.gridLayout.addWidget(self.label_6, 1, 1, 1, 1)
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setObjectName("label_3")
        self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1)
        mainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(mainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1920, 30))
        self.menubar.setObjectName("menubar")
        mainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(mainWindow)
        self.statusbar.setObjectName("statusbar")
        mainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(mainWindow)
        QtCore.QMetaObject.connectSlotsByName(mainWindow)
        self.label_2.setVisible(False)
        self.label_3.setVisible(False)
        self.label_4.setVisible(False)
        self.label_5.setVisible(False)
        self.label_6.setVisible(False)
        self.label_7.setVisible(False)

    def retranslateUi(self, mainWindow):
        _translate = QtCore.QCoreApplication.translate
        mainWindow.setWindowTitle(_translate("mainWindow", "數字儀表智能采集演示系統"))
        self.pushButton.setText(_translate("mainWindow", "開始采集"))
        self.label.setText(_translate("mainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
"p, li { white-space: pre-wrap; }\n"
"</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
"<p style=\" margin-top:12px; margin-bottom:12px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:36pt; font-weight:600; color:#FFFF00;\">數字儀表智能采集演示系統</span></p></body></html>"))
        self.label_2.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">1號燈</span></p></body></html>"))
        self.label_4.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">3號燈</span></p></body></html>"))
        self.label_3.setText(_translate("mainWindow", "<html><head/><body><p><span style=\" color:#ffffff;\">2號燈</span></p></body></html>"))

from pyqtgraph import PlotWidget
Finally, the main program:
from PyQt5.QtWidgets import QApplication, QMainWindow
from PyQt5.QtCore import pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtGui import QIcon  # icon for the window
import sys
import array
import cv2
from demo import Ui_mainWindow
import numpy as np
from imutils.video import WebcamVideoStream
import readvc_box_03 as read_box
from PIL import Image, ImageDraw, ImageFont
import tensorflow as tf
from main import load_model, predict
import random
import time

model_dir = './tmp/train_data_model_13/'


class MeterData(object):
    def __init__(self):
        self.saved_points_num = 40
        self.meter_num = 4
        self.meters = [array.array('d') for i in range(self.meter_num)]
        self.cur_numbers = [0. for i in range(self.meter_num)]

    def add_meters_number(self, numbers):
        if len(numbers) == self.meter_num:
            self.cur_numbers = numbers
            for i in range(self.meter_num):
                self.add_meter_number(self.cur_numbers[i], i)

    def add_meter_number(self, number, id):
        if len(self.meters[id]) < self.saved_points_num:
            self.meters[id].append(number)
        else:
            self.meters[id][:-1] = self.meters[id][1:]
            self.meters[id][-1] = number


class ImageData(object):
    def __init__(self):
        self.pixmap = None

    def add_image(self, img):
        rgb_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        rgb_img = cv2.resize(rgb_img, (640, 480), interpolation=cv2.INTER_CUBIC)
        q_img = self.get_qimage(rgb_img)
        self.pixmap = QPixmap.fromImage(q_img)

    def get_qimage(self, image):
        height, width, colors = image.shape
        bytesPerLine = 3 * width
        image = QImage(image.data, width, height, bytesPerLine, QImage.Format_RGB888)
        image = image.rgbSwapped()
        return image


class LightData(object):
    def __init__(self):
        self.one = ''
        self.two = ''
        self.three = ''

    def set(self, red, orange, green):
        '''
        Takes three color dictionaries, e.g.:
        {'r': [(133, 138, 28, 26)]}
        {'o': [(183, 139, 29, 28)]}
        {'g': [(235, 140, 27, 29)]}
        '''
        self.one = self.get(red, orange, green)[0]
        self.two = self.get(red, orange, green)[1]
        self.three = self.get(red, orange, green)[2]

    def get(self, red, orange, green):
        # Sort the three boxes left to right by their x coordinate
        a = sorted([red, orange, green], key=lambda x: list(x.values())[0][0][0])
        color_list = []
        for item in a:
            color_list.append(list(item.keys())[0])
        return color_list


meterData = MeterData()
imageData = ImageData()
lightData = LightData()


class workThread(QThread):
    trigger = pyqtSignal()

    def __init__(self):
        super(workThread, self).__init__()

    def run(self):
        # Load the model and restore the session only once (see above)
        saver, inputs, seq_len, decoded, log_prob = load_model()
        sess = tf.Session()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        print("[INFO] camera sensor warming up...")
        vs = WebcamVideoStream(src=0).start()
        print('camera ok')
        while True:
            # Set camera properties (via a separate capture handle)
            cap = cv2.VideoCapture(0)
            cap.set(cv2.CAP_PROP_FOCUS, 160)
            cap.set(cv2.CAP_PROP_EXPOSURE, -8)
            cap.set(cv2.CAP_PROP_BACKLIGHT, 0)
            # Read a frame from the camera and process it
            frame = vs.read()
            # todo: the read_box calls could be inlined here
            # Compute the gradient of the current video frame
            gradient = read_box.img_preprocessing(frame)
            # Get the coordinates of the four small images
            list_box = read_box.get_box(gradient, num=4)
            # Draw the rectangles from the coordinates
            for i, bbx in enumerate(list_box):
                x, y, w, h = bbx
                draw_img = cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 1)
            # Crop the four small images out of the frame; returns an image dict
            # and lb, the sorted coordinate list
            try:
                dict_crap, lb = read_box.create_crapimg(frame, list_box)
            except:
                import traceback
                traceback.print_exc()
                continue
            if not lb:
                continue
            # Predict the digits; sess is a variable initialized only once
            t1 = time.time()
            predict_img_dict = predict(dict_crap, saver, inputs, seq_len, decoded, log_prob, sess)
            t2 = time.time()
            print(t2 - t1)
            # Defaults, in case recognition fails on this frame
            showimg = None
            num_list = []
            # todo: when a prediction is not four digits, save the offending
            # image locally, e.g. (and similarly for indices 2, 3, 4):
            # if len(predict_img_dict[1]) != 4:
            #     image = dict_crap[1 - 1]
            #     name = predict_img_dict[1]
            #     cv2.imwrite('./error_imgs/0/' + str(name) + '_' + str(random.randint(0, 1000)) + '.jpg', image)
            if predict_img_dict is not None:
                num_list = [int(num) if num else int(8888) for num in predict_img_dict.values()]
                if len(num_list) >= 3:
                    num_list[1], num_list[2] = num_list[2], num_list[1]
                if len(str(num_list[0])) != 4 or len(str(num_list[1])) != 4 or len(str(num_list[2])) != 4 or len(
                        str(num_list[3])) != 4:
                    continue
                for idx, box in enumerate(lb):
                    # Draw the predicted number on the image
                    showimg = cv2.putText(draw_img, str(predict_img_dict[idx + 1]),
                                          (int((box[0])), int(box[1] - 10)),
                                          cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA)
            else:
                pass

            global meterData, imageData, lightData
            # Three-color light detection: each call draws a box and returns its
            # coordinates keyed by color, e.g.:
            # g [(237, 85, 31, 31)]
            # r [(291, 85, 31, 31)]
            # o [(345, 86, 32, 31)]
            list_box_red = get_color_red(frame)
            list_box_orange = get_color_orange(frame)
            list_box_green = get_color_green(frame)
            if list_box_red["紅色"] != [] and list_box_orange["黃色"] != [] and list_box_green["綠色"] != []:
                # All three were found, so classify the colors
                lightData.set(list_box_red, list_box_orange, list_box_green)
            # Overall check: if neither digits nor colors were recognized,
            # just show whatever the camera captured
            if draw_img is not None and frame is not None and gradient is not None \
                    and list_box is not None and showimg is not None:
                if list_box_red["紅色"] != [] and list_box_orange["黃色"] != [] and list_box_green["綠色"] != []:
                    red_x, red_y = list_box_red["紅色"][0][0], list_box_red["紅色"][0][1]
                    orange_x, orange_y = list_box_orange["黃色"][0][0], list_box_orange["黃色"][0][1]
                    green_x, green_y = list_box_green["綠色"][0][0], list_box_green["綠色"][0][1]
                    # Convert from OpenCV to PIL format to draw Chinese text
                    img_PIL = Image.fromarray(cv2.cvtColor(showimg, cv2.COLOR_BGR2RGB))
                    font = ImageFont.truetype('font/simsun.ttc', 20)  # font size 20; adjust as needed
                    draw = ImageDraw.Draw(img_PIL)
                    draw.text((red_x, red_y - 20), '紅', font=font, fill=(255, 0, 0))
                    draw.text((orange_x, orange_y - 20), '黃', font=font, fill=(255, 255, 0))
                    draw.text((green_x, green_y - 20), '綠', font=font, fill=(0, 255, 0))
                    # Convert back to OpenCV format
                    frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
                else:
                    # Three-color light detection failed
                    frame = showimg
            else:
                frame = frame
            values = num_list
            image = frame
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            meterData.add_meters_number(values)
            imageData.add_image(image)
            self.trigger.emit()
        sess.close()


class MeterMainWindow(QMainWindow, Ui_mainWindow):
    updateSignal = pyqtSignal()

    def __init__(self, parent=None):
        super(MeterMainWindow, self).__init__(parent)
        self.setupUi(self)
        self.initUi()

    def initUi(self):
        self.curves = [self.graphicsView.plot(pen='y'), self.graphicsView_2.plot(pen='y'),
                       self.graphicsView_3.plot(pen='y'), self.graphicsView_4.plot(pen='y')]
        # self.numbers = [self.lcdNumber, self.lcdNumber_2, self.lcdNumber_3, self.lcdNumber_4]
        self.workThread = workThread()
        self.pushButton.clicked.connect(self.start)

    def start(self):
        self.workThread.start()
        self.workThread.trigger.connect(self.update)

    def update(self):
        # Refresh the detection results in the UI
        global meterData, imageData, lightData
        self.imageLabel.setPixmap(imageData.pixmap)
        self.label_5.setText(lightData.one)
        self.label_6.setText(lightData.two)
        self.label_7.setText(lightData.three)
        for i, curve in enumerate(self.curves):
            curve.setData(meterData.meters[i])


def get_color_red(img):
    # 1. Convert RGB to the HSV color space
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)  # 3 channels
    # 2. Define the HSV range of the target color (here: red)
    # lower_hsv = np.array([0, 69, 194])
    # upper_hsv = np.array([24, 141, 255])
    lower_hsv = np.array([0, 84, 250])
    upper_hsv = np.array([25, 111, 255])
    # 3. Filter and extract, producing a binary image
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)  # 1 channel
    # Gaussian blur to remove noise
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    # Erode
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    # Dilate
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    # Morphological gradient = dilation - erosion
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)
    # Find the contours of the gradient
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # A list to hold the bounding-box coordinates
    list_box = []
    # Keep only the largest contour by area
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        x, y, w, h = cv2.boundingRect(c)
        list_box.append((x, y, w, h))
        # Draw the rectangle
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"紅色": list_box}


def get_color_orange(img):
    # Same pipeline as get_color_red, with the HSV range of the yellow light
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_hsv = np.array([15, 127, 147])
    upper_hsv = np.array([30, 180, 255])
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    list_box = []
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        x, y, w, h = cv2.boundingRect(c)
        list_box.append((x, y, w, h))
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"黃色": list_box}


def get_color_green(img):
    # Same pipeline as get_color_red, with the HSV range of the green light
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_hsv = np.array([72, 145, 149])
    upper_hsv = np.array([89, 255, 255])
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    dige_dilate = cv2.dilate(erosion, kernel, iterations=4)
    gradient = cv2.morphologyEx(dige_dilate, cv2.MORPH_GRADIENT, kernel)
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    list_box = []
    cnt = sorted(contours, key=cv2.contourArea, reverse=True)[:1]
    for c in cnt:
        x, y, w, h = cv2.boundingRect(c)
        list_box.append((x, y, w, h))
        draw_img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    list_box = list(set(list_box))
    return {"綠色": list_box}


def main():
    app = QApplication(sys.argv)
    # Set a window icon
    app.setWindowIcon(QIcon('./images/mainimg.ico'))
    main_Window = MeterMainWindow()
    main_Window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()
Note: get_color_red(img) is one of three near-identical functions; the idea is the same as the color-based digit extraction above, and together they classify the three colored lights.
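Since the three functions differ only in their HSV bounds and dictionary keys, they could be folded into one parameterized helper (an illustrative refactor of the code above, not the code actually used):

COLOR_RANGES = {
    "紅色": (np.array([0, 84, 250]), np.array([25, 111, 255])),
    "黃色": (np.array([15, 127, 147]), np.array([30, 180, 255])),
    "綠色": (np.array([72, 145, 149]), np.array([89, 255, 255])),
}

def get_color(img, name):
    # Same pipeline as get_color_red/orange/green, parameterized by color name
    lower_hsv, upper_hsv = COLOR_RANGES[name]
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    mask_ = cv2.inRange(hsv, lower_hsv, upper_hsv)
    gaussian = cv2.GaussianBlur(mask_, (3, 3), 1)
    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(gaussian, kernel, iterations=1)
    dilated = cv2.dilate(erosion, kernel, iterations=4)
    gradient = cv2.morphologyEx(dilated, cv2.MORPH_GRADIENT, kernel)
    contours, _ = cv2.findContours(gradient, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    list_box = []
    for c in sorted(contours, key=cv2.contourArea, reverse=True)[:1]:
        x, y, w, h = cv2.boundingRect(c)
        list_box.append((x, y, w, h))
        cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1)
    return {name: list(set(list_box))}

# get_color(frame, "紅色") is then equivalent to get_color_red(frame)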
A follow-up optimization idea: put a CNN in front of the LSTM to extract image features, serialize the extracted features, and feed that sequence into the LSTM.
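A rough sketch of what that could look like in the same TF 1.x style (the layer sizes and function name are my own assumptions; the final reshape turns the feature map into one feature vector per image column, preserving the sequence length expected by the LSTM):

def cnn_feature_extractor(inputs):
    # inputs: [batch, width, height], as in get_train_model; add a channel axis
    x = tf.expand_dims(inputs, -1)  # [B, W, H, 1]
    x = tf.layers.conv2d(x, 32, 3, padding='same', activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, pool_size=(1, 2), strides=(1, 2))  # halve the height only
    x = tf.layers.conv2d(x, 64, 3, padding='same', activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, pool_size=(1, 2), strides=(1, 2))
    # Flatten each image column into a feature vector: a sequence over the width
    shape = tf.shape(x)
    seq = tf.reshape(x, [shape[0], shape[1], (OUTPUT_SHAPE[0] // 4) * 64])
    return seq  # feed this to dynamic_rnn instead of the raw pixels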
