Python Opencv-contrib Camshift&kalman卡尔曼滤波&CSRT算法 目标跟踪实现


本文为原创文章,转载请注明出处。

本次课题实现目标跟踪一共用到了三个算法,分别是Camshift、Kalman、CSRT,基于Python语言的Tkinter模块实现GUI与接口设计,项目一共包含三个文件:

main.py:

  1 # coding:utf-8
  2 # 主模块
  3 
  4 
  5 import Tkinter
  6 import tkFileDialog
  7 import cv2
  8 import time
  9 from PIL import ImageTk
 10 # 导入自定义模块
 11 import track
 12 import utils
 13 
 14 
# Create the 800x480 main window
root = Tkinter.Tk()
root.title("基于视频的实时行人追踪")
root.geometry("800x480")

# Background: a canvas covering the whole window that shows the background image
canvas = Tkinter.Canvas(root, width=800, height=480, highlightthickness=0, borderwidth=0)
background_image = ImageTk.PhotoImage(file="background.jpg")  # background picture, looked up relative to the working directory
canvas.create_image(0, 0, anchor="nw", image=background_image)
canvas.pack()

# Title label, embedded in the canvas at (400, 100)
label_a = Tkinter.Label(root, text="基于视频的实时行人追踪", font=("KaiTi", 20), height=2)
label_a.pack()
canvas.create_window(400, 100, height=25, window=label_a)

# Text variable behind the path label; starts with the "please choose a folder" prompt
show_path = Tkinter.StringVar()
show_path.set("请选择一个文件夹")

# Path label bound to show_path, embedded in the canvas at (400, 150)
label_b = Tkinter.Label(root, textvariable=show_path, font=("Times New Roman", 15), height=2)
label_b.pack()
canvas.create_window(400, 150, window=label_b)

# Shared holder for the target box chosen by the user
ROI = utils.ROI()
# Shared holder for the paths of the selected sequence
path = utils.Path()
 44 
 45 
 46 # 选择序列
 47 def hit_button_a():
 48     path.init(tkFileDialog.askdirectory(title="Select Folder"))
 49     # 显示路径
 50     if path.img_path != "":
 51         show_path.set("文件路径:" + str(path.img_path)[:-1] + "\n序列总数:" + str(path.sum))
 52     else:
 53         show_path.set("路径错误!")
 54 
 55 
# "Select sequence" button: runs hit_button_a and is embedded in the canvas at (400, 200)
button_a = Tkinter.Button(root, text="选择序列", font=("KaiTi", 15), height=2, command=hit_button_a)
button_a.pack()
canvas.create_window(400, 200, height=20, window=button_a)
 59 
 60 
 61 # ROI
 62 def hit_button_b():
 63     # 读取首帧图像
 64     first_image = cv2.imread(path.pics_list[0])
 65     # ROI
 66     ROI.init_window(cv2.selectROI(windowName="ROI", img=first_image, showCrosshair=True, fromCenter=False))
 67     cv2.destroyAllWindows()
 68 
 69 
 70 button_b = Tkinter.Button(root, text="标记目标", font=("KaiTi", 15), heigh=2, command=hit_button_b)
 71 button_b.pack()
 72 canvas.create_window(400, 250, height=20, window=button_b)
 73 
 74 
 75 # 目标追踪
 76 
 77 def hit_button_c():
 78     global camshift, kcf, csrt
 79     index = utils.index(path.groundtruth_path)  # 读取真值
 80     firstframe = True
 81     kalman_xy = track.KalmanFilter()
 82     kalman_size = track.KalmanFilter()
 83     bbox = [0, 0, 0, 0]
 84 
 85     for i in range(0, path.sum):
 86         start = time.time()  # 开始计时
 87         frame = cv2.imread(path.pics_list[i])  # 读取
 88         if firstframe:
 89             camshift = track.Camshift(frame, ROI.window)
 90             kcf = track.KCFtracker(frame, ROI.window)
 91             firstframe = False
 92             continue
 93         # camshift.update(frame)
 94         ok = kcf.update(frame)
 95         if not ok:
 96             mes = (bbox[0], bbox[1], bbox[2], bbox[3])
 97             print mes
 98             kcf.tracker.init(frame, mes)
 99             ok = kcf.update(frame)
100         end = time.time()  # 结束计时
101         seconds = end - start  # 处理用时
102         groundtruth = index.groundtruth(i)  # 真值
103         window = camshift.window
104         window = kcf.window
105 
106         A = window[0] - ROI.window[0]
107         B = window[1] - ROI.window[1]
108         C = window[2] - ROI.window[2]
109         D = window[3] - ROI.window[3]
110         xy = kalman_xy.predict(A, B)
111         size = kalman_size.predict(C, D)  # 卡尔曼滤波
112 
113         bbox[0] = int(ROI.window[0] + xy[0])
114         bbox[1] = int(ROI.window[1] + xy[1])
115         bbox[2] = int(ROI.window[2] + size[0])
116         bbox[3] = int(ROI.window[3] + size[1])
117 
118         ape = index.APE(bbox, groundtruth)  # 像素误差
119         aor = index.AOR(bbox, groundtruth)  # 重叠率
120         # 绘制数据曲线
121         # eva.draw(FPS, ape, aor, i)
122         frame = utils.display(seconds, frame, bbox, ape, aor, groundtruth, truth=False)  # 跟踪框
123         # 显示
124         cv2.imshow("Track", frame)
125         t = cv2.waitKey(20) & 0xff
126         # 按空格键停止
127         if t == ord(" "):
128             cv2.waitKey(0)
129         # 按ESC键退出
130         if t == 27:
131             cv2.destroyAllWindows()
132             break
133     cv2.destroyAllWindows()
134     print ("跟踪结束!\n")
135 
136 
# "Start tracking" button: runs hit_button_c and is embedded in the canvas at (400, 300).
# The option is spelled out as height (it was the abbreviation "heigh").
button_c = Tkinter.Button(root, text="开始追踪", font=("KaiTi", 15), height=2, command=hit_button_c)
button_c.pack()
canvas.create_window(400, 300, height=20, window=button_c)


# Hand control to the Tk event loop
root.mainloop()

自定义跟踪器模块track.py:

 1 # coding:utf-8
 2 # 追踪器模块
 3 
 4 
 5 import cv2
 6 import numpy as np
 7 
 8 
 9 # 得到中心点
10 def center(points):
11     x = (points[0][0] + points[1][0] + points[2][0] + points[3][0]) / 4
12     y = (points[0][1] + points[1][1] + points[2][1] + points[3][1]) / 4
13     return np.array([np.float32(x), np.float32(y)], np.float32)
14 
15 
class Camshift:
    """CamShift tracker driven by the hue histogram of the initial ROI.

    ``window`` holds the current search window as ``(x, y, w, h)``.
    """

    def __init__(self, frame, ROI):
        """Build the hue histogram of ``ROI`` (x, y, w, h) inside the BGR ``frame``."""
        # Slicing and cv2.CamShift need integer coordinates
        x, y, w, h = [int(v) for v in ROI]
        self.window = (x, y, w, h)
        roi = frame[y:y + h, x:x + w]  # crop the target region
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        # Keep only pixels with S >= 60 and V >= 32 when building the histogram
        mask = cv2.inRange(hsv_roi, np.array((0., 60., 32.)), np.array((180., 255., 255.)))
        self.hist = cv2.calcHist([hsv_roi], [0], mask, [180], [0, 180])  # hue histogram, 180 bins
        cv2.normalize(self.hist, self.hist, 0, 255, cv2.NORM_MINMAX)  # scale to 0..255

    def update(self, frame):
        """Run CamShift on ``frame`` and store the new search window in ``window``."""
        # Termination criteria are (type, max_iterations, epsilon): stop after
        # ten iterations or when the window moves by less than one pixel.
        term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1)
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        dst = cv2.calcBackProject([hsv], [0], self.hist, [0, 180], 1)  # back projection of the hue histogram
        x, y, w, h = self.window
        ret, (x, y, w, h) = cv2.CamShift(dst, (x, y, w, h), term_crit)
        self.window = (x, y, w, h)
35 
36 
class MILtracker:
    """Thin wrapper around OpenCV's MIL tracker; ``window`` is the current box."""

    def __init__(self, frame, ROI):
        """Create the tracker and initialise it on ``frame`` with the box ``ROI``."""
        self.window = ROI
        self.tracker = cv2.TrackerMIL_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update the tracker with ``frame``, store the new box and return the success flag."""
        ok, self.window = self.tracker.update(frame)
        # Return the flag like KCFtracker.update does, so callers can detect a lost target
        return ok
45 
46 
class KCFtracker:
    """Tracker wrapper used by the main module; ``window`` is the current box.

    NOTE: despite its name this class runs OpenCV's CSRT tracker. The original
    code created a KCF tracker and immediately replaced it with a CSRT one, so
    only the CSRT construction is kept.
    """

    def __init__(self, frame, ROI):
        """Create the tracker and initialise it on ``frame`` with the box ``ROI``."""
        self.window = ROI
        self.tracker = cv2.TrackerCSRT_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update the tracker with ``frame``, store the new box and return the success flag."""
        ok, self.window = self.tracker.update(frame)
        return ok
57 
class CSRTtracker:
    """Thin wrapper around OpenCV's CSRT tracker; ``window`` is the current box."""

    def __init__(self, frame, ROI):
        """Create the tracker and initialise it on ``frame`` with the box ``ROI``."""
        self.window = ROI
        self.tracker = cv2.TrackerCSRT_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update the tracker with ``frame``, store the new box and return the success flag."""
        ok, self.window = self.tracker.update(frame)
        # Return the flag like KCFtracker.update does, so callers can detect a lost target
        return ok
66 
67 
class KalmanFilter:
    """Wrapper around cv2.KalmanFilter with 4 state values and 2 measured values.

    The measurement matrix selects the first two state values; the transition
    matrix adds state value 2 to value 0 and state value 3 to value 1 at
    every step.
    """

    def __init__(self):
        """Create the filter and set its model and noise matrices."""
        kf = cv2.KalmanFilter(4, 2)
        kf.measurementMatrix = np.eye(2, 4, dtype=np.float32)
        kf.transitionMatrix = np.array(
            [
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ],
            np.float32,
        )
        kf.processNoiseCov = np.eye(4, dtype=np.float32) * 0.003
        kf.measurementNoiseCov = np.eye(2, dtype=np.float32) * 0.01
        self.kalman = kf

    def predict(self, x, y):
        """Correct the filter with the measurement (x, y) and return its next prediction."""
        measurement = np.array([[np.float32(x)], [np.float32(y)]])
        self.kalman.correct(measurement)
        return self.kalman.predict()

自定义的工具模块utils.py:

  1 # coding:utf-8
  2 # 工具模块
  3 
  4 
  5 import cv2
  6 import os
  7 import re
  8 
  9 
 10 # 目录存储模块
 11 class Path:
 12     # 存储文件路径
 13     def __init__(self):
 14         self.img_path = ""
 15         self.groundtruth_path = ""
 16         # 目录
 17         self.inpics_list = []
 18         # 绝对路径目录
 19         self.pics_list = []
 20         self.sum = 0
 21 
 22     # 初始化文件路径
 23     def init(self, path):
 24         # 请选择包含img和groundtruth的总文件夹
 25         if path != '':
 26             self.img_path = path + "/img/"
 27             self.groundtruth_path = path + "/groundtruth.txt"
 28             self.inpics_list = os.listdir(self.img_path)
 29             self.inpics_list.sort()
 30         # 目录统计
 31         self.sum = len(self.inpics_list)
 32         # 绝对路径
 33         self.pics_list = [self.img_path + x for x in self.inpics_list]
 34 
 35 
 36 # 坐标储存模块
 37 class ROI:
 38     # 存储坐标
 39     def __init__(self):
 40         self.x = 0
 41         self.y = 0
 42         self.width = 0
 43         self.height = 0
 44         self.window = []
 45 
 46     # 单坐标的初始化
 47     def init(self, x, y, width, height):
 48         self.x = x
 49         self.y = y
 50         self.width = width
 51         self.height = height
 52         self.window = (x, y, width, height)
 53 
 54     # 窗口坐标的初始化
 55     def init_window(self, window):
 56         self.x = window[0]
 57         self.y = window[1]
 58         self.width = window[2]
 59         self.height = window[3]
 60         self.window = (window[0], window[1], window[2], window[3])
 61 
 62 
 63 # 评价指标模块
 64 class index:
 65     def __init__(self, path):
 66         self.fps = []
 67         self.ape = []
 68         self.aor = []
 69         self.n = []
 70         # 载入真值
 71         self.lines = open(path).readlines()
 72 
 73     # 得到真值
 74     def groundtruth(self, i):
 75         line = [x for x in self.lines]
 76         # 切割
 77         window = [0, 0, 0, 0]
 78         for n in range(0, 4):
 79             window[n] = int(re.split("[,\n\t ]", line[i])[n])
 80         return window
 81 
 82     # 像素误差
 83     @staticmethod
 84     def APE(window, bbox):
 85         x1, y1, w1, h1 = window
 86         x2, y2, w2, h2 = bbox
 87         # 跟踪框中心
 88         center = [int(x1 + 1 / 2 * w1), int(y1 + 1 / 2 * h1)]
 89         # 真值中心
 90         truth_center = [int(x2 + 1 / 2 * w2), int(y2 + 1 / 2 * h2)]
 91         # 计算像素误差
 92         ape = pow(pow(center[0] - truth_center[0], 2) + pow(center[1] - truth_center[1], 2), .2)
 93         ape = round(ape, 2)
 94         return ape
 95 
 96     # 重叠率
 97     @staticmethod
 98     def AOR(window, bbox):
 99         x1, y1, w1, h1 = window
100         x2, y2, w2, h2 = bbox
101         col = min(x1 + w1, x2 + w2) - max(x1, x2)
102         row = min(y1 + h1, y2 + h2) - max(y1, y2)
103         intersection = col * row
104         area1 = w1 * h1
105         area2 = w2 * h2
106         coincide = intersection * 1.0 / (area1 + area2 - intersection) * 100
107         aor = round(coincide, 2)
108         return aor
109 
110     # 绘制数据曲线
111     def draw(self, fps, ape, aor, number):
112         self.fps.append(fps)
113         self.ape.append(ape)
114         self.aor.append(aor)
115         self.n.append(number)
116 
117 
# Draw the tracking box and the metrics
def display(seconds, img, window, ape, aor, groundtruth, truth=False):
    """Annotate ``img`` with the tracking box, its centre and the metric texts.

    Args:
        seconds: processing time of the frame; FPS is shown as 1 / seconds
            (0 when no time was measured).
        img: image passed to the cv2 drawing calls.
        window: tracking box ``(x, y, w, h)``; values are converted to int.
        ape: pixel error, printed as given.
        aor: overlap rate, printed as given followed by ``%``.
        groundtruth: ground-truth box ``(x, y, w, h)``, drawn only when
            ``truth`` is true.
        truth: whether to draw the ground-truth box as well.

    Returns:
        The image returned by the cv2 drawing calls.
    """
    x, y, w, h = [int(v) for v in window]
    # Tracking box
    img = cv2.rectangle(img, (x, y), (x + w, y + h), (255, 0, 0), 2)
    if truth:
        a, b, c, d = groundtruth
        img = cv2.rectangle(img, (a, b), (a + c, b + d), (0, 255, 0), 2)
    # Centre point; floor division keeps the coordinates integral on Python 3 too
    xc = x + w // 2
    yc = y + h // 2
    cv2.circle(img, (xc, yc), 3, (255, 0, 0), -1)
    # Coordinates
    text = cv2.FONT_HERSHEY_COMPLEX_SMALL
    size = 1
    cv2.putText(img, ('X=' + str(xc)), (10, 20), text, size, (0, 0, 255), 1, cv2.LINE_AA)
    cv2.putText(img, ('Y=' + str(yc)), (10, 50), text, size, (0, 0, 255), 1, cv2.LINE_AA)
    # FPS; a timer reading of 0 must not raise ZeroDivisionError
    fps = 1.0 / seconds if seconds > 0 else 0.0
    cv2.putText(img, ('FPS = ' + str(int(fps))), (10, 80), text, size, (0, 255, 0), 1, cv2.LINE_AA)
    cv2.putText(img, ('APE = ' + str(ape)) + 'pixels', (10, 110), text, size, (0, 255, 255), 1, cv2.LINE_AA)
    cv2.putText(img, ('AOR = ' + str(aor) + '%'), (10, 140), text, size, (255, 0, 255), 1, cv2.LINE_AA)
    return img
143 
def dis(window, img):
    """Draw the box ``window`` (x, y, w, h) on ``img`` and return cv2.rectangle's result."""
    box = [int(value) for value in window]
    left, top, box_w, box_h = box
    top_left = (left, top)
    bottom_right = (left + box_w, top + box_h)
    return cv2.rectangle(img, top_left, bottom_right, (255, 0, 0), 2)

注:

1.在项目目录下保存一张GUI界面的背景图像background.jpg。

2.选择样本序列时,所选文件夹需包含:子文件夹img(保存从0001.jpg开始按顺序编号的全部序列图像),以及真值文件groundtruth.txt。

3.务必使用低版本的Opencv-contrib(本人测试时只有低版本可以正常使用,具体原因未知),否则不能使用CSRT跟踪器。

TBD.


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM