整個項目源碼:GitHub
引言
前面我們講完交通標志的識別,現在我們開始嘗試來實現交通信號燈的識別
接下來我們將按照自己的思路來實現並完善整個Project.
在這個項目中,我們使用HSV色彩空間來識別交通燈,可以改善及提高的地方:
- 可以采用Faster-RCNN或SSD來實現交通燈的識別
首先我們第一步是導入數據,並在RGB及HSV色彩空間可視化部分數據。這里的數據,我們采用MIT自動駕駛課程的圖片,
總共三類:紅綠黃,1187張圖片,其中,723張紅色交通燈圖片,429張綠色交通燈圖片,35張黃色交通燈圖片。
導入庫
-
# import some libs
import glob
import os
import random

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
-
# Image data directories (each holds one sub-folder per light color)
IMAGE_DIR_TRAINING = "traffic_light_images/training/"
IMAGE_DIR_TEST = "traffic_light_images/test/"
-
-
#load data
-
def load_dataset(image_dir):
    """Load every image under ``image_dir`` together with its color label.

    ``image_dir`` is searched for the sub-folders ``red``, ``yellow`` and
    ``green`` (in that order). The number of files found in each sub-folder
    is printed as a quick sanity check.

    Args:
        image_dir: Directory that contains the per-color image folders.

    Returns:
        A list of ``(image, label)`` tuples, where ``image`` is whatever
        ``mpimg.imread`` returns and ``label`` is the folder name. Entries
        are grouped by label: all red first, then yellow, then green.
    """
    im_list = []
    image_types = ['red', 'yellow', 'green']

    # Iterate through each color folder
    for im_type in image_types:
        # glob gives no ordering guarantee; sort so list indices are reproducible.
        file_list = sorted(glob.glob(os.path.join(image_dir, im_type, '*')))
        print(len(file_list))

        for path in file_list:
            im = mpimg.imread(path)
            if im is not None:
                im_list.append((im, im_type))

    return im_list
-
IMAGE_LIST = load_dataset(IMAGE_DIR_TRAINING)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
-
723
-
35
-
429
- 1
- 2
- 3
- 4
Visualize the data
這里可視化主要實現:
- 顯示圖像
- 打印出圖片的大小
- 打印出圖片對應的標簽
-
# Show one sample image per class, annotated with its label and titled with
# its shape. The figure handle is not needed, so it is bound to `_`.
_, ax = plt.subplots(1, 3, figsize=(5, 2))

# Indices 0 / 730 / 800 are the positions the original notebook used to pick
# a red / yellow / green sample out of IMAGE_LIST.
img_red = IMAGE_LIST[0][0]
img_yellow = IMAGE_LIST[730][0]
img_green = IMAGE_LIST[800][0]

samples = (
    (img_red, IMAGE_LIST[0][1]),        # red
    (img_yellow, IMAGE_LIST[730][1]),   # yellow
    (img_green, IMAGE_LIST[800][1]),    # green
)
for axis, (sample_img, sample_label) in zip(ax, samples):
    axis.imshow(sample_img)
    axis.annotate(sample_label, xy=(2, 5), color='blue', fontsize='10')
    axis.axis('off')
    axis.set_title(sample_img.shape, fontsize=10)

plt.show()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
PreProcess Data
在導入了上述數據后,接下來我們需要標准化輸入及輸出
Input
從上圖,我們可以看出,每張圖片的大小並不一樣,因此我們需要標準化輸入:
將每張圖片resize成相同的大小。
因為對於分類任務來說,我們需要在每張圖片上應用相同的算法,
因此標準化圖像尤其重要。
Output
這里我們的標簽數據是類別數據:’red’,’yellow’,’green’,因此我們可以利用one_hot方法將類別數據轉換成數值數據
-
# Standardize the input images: every image is resized to 32x32x3. Cropping,
# shifting or rotating the images could also be applied at this stage.
def standardize(image_list):
    """Standardize a collection of labelled RGB images.

    Each image is passed through ``standardize_input`` and each label through
    ``one_hot_encode``.

    Args:
        image_list: Iterable of items whose element 0 is an RGB image and
            whose element 1 is its label.

    Returns:
        A list of ``(standardized_image, one_hot_label)`` tuples, in the same
        order as the input.
    """
    standard_list = []

    # Iterate through all the image-label pairs
    for item in image_list:
        image, label = item[0], item[1]

        # Standardize the input (image) and the output (one-hot label)
        standardized_im = standardize_input(image)
        one_hot_label = one_hot_encode(label)

        # Append the image and its one-hot label to the processed list
        standard_list.append((standardized_im, one_hot_label))

    return standard_list
-
-
def standardize_input(image, size=(32, 32)):
    """Resize ``image`` to a common size so every image is processed alike.

    Args:
        image: Image array accepted by ``cv2.resize``.
        size: Target size, passed unchanged to ``cv2.resize`` as its second
            argument. Defaults to ``(32, 32)``, the size used in this project.

    Returns:
        The resized image returned by ``cv2.resize``.
    """
    standard_im = cv2.resize(image, size)
    return standard_im
-
-
def one_hot_encode(label):
    """Return the one-hot encoding of a traffic-light color label.

    Examples:
        one_hot_encode("red")    returns [1, 0, 0]
        one_hot_encode("yellow") returns [0, 1, 0]
        one_hot_encode("green")  returns [0, 0, 1]

    Args:
        label: One of ``'red'``, ``'yellow'`` or ``'green'``.

    Returns:
        A new three-element list ``[red, yellow, green]``.

    Raises:
        ValueError: If ``label`` is not one of the three known colors. The
            previous implementation silently encoded any unknown label as
            green, which hides mislabelled data.
    """
    if label == 'red':
        return [1, 0, 0]
    if label == 'yellow':
        return [0, 1, 0]
    if label == 'green':
        return [0, 0, 1]
    raise ValueError(
        "Unknown label {!r}; expected 'red', 'yellow' or 'green'".format(label)
    )
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
Test your code
實現完了上述標准化代碼后,我們需要進一步確定我們的代碼是正確的,因此接下來我們可以實現一個函數來實現上述代碼功能的檢驗
用Python搭建自動化測試框架,我們需要組織用例以及測試執行,這里我們推薦Python的標准庫——unittest。
-
import unittest

from IPython.display import Markdown, display


# Helper function for printing markdown text (text in color/bold/etc)
def printmd(string):
    """Display ``string`` as Markdown via IPython's ``display``.

    Args:
        string: Markdown (or inline HTML) source text to show.
    """
    display(Markdown(string))
-
# Print a "test failed" message
def print_fail():
    """Show a red 'Test Failed' banner through ``printmd``."""
    # The style attribute needs a single '='; the original 'style=="..."'
    # is malformed HTML, so the red color was not reliably applied.
    printmd('<span style="color: red;">Test Failed</span>')
-
def print_pass():
    """Show a green 'Test Passed' banner through ``printmd``."""
    printmd('<span style="color:green;">Test Passed</span>')
-
# A class holding all tests
-
class Tests(unittest.TestCase):
    """Checks that report their result through ``print_pass``/``print_fail``.

    Both methods take the object under test as an extra argument, so they are
    meant to be called directly on an instance rather than discovered by a
    test runner.
    """

    def test_one_hot(self, one_hot_function):
        """Check ``one_hot_function`` against the three expected encodings.

        Args:
            one_hot_function: Callable mapping 'red'/'yellow'/'green' to a
                one-hot list.
        """
        # Test that the generated one-hot labels match the expected labels
        # for all three cases (red, yellow, green).
        try:
            self.assertEqual([1, 0, 0], one_hot_function('red'))
            self.assertEqual([0, 1, 0], one_hot_function('yellow'))
            self.assertEqual([0, 0, 1], one_hot_function('green'))
        except self.failureException as e:
            # Report the failure and stop; nothing is re-raised.
            print_fail()
            print('Your function did not return the excepted one-hot label')
            print('\n' + str(e))
            return

        print_pass()

    def test_red_aa_green(self, misclassified_images):
        """Fail if any red light was predicted to be green.

        Args:
            misclassified_images: Iterable of
                ``(image, predicted_label, true_label)`` triples with one-hot
                labels.
        """
        # Loop through each misclassified image and its labels
        for im, predicted_label, true_label in misclassified_images:
            # Only red lights matter for this check
            if true_label == [1, 0, 0]:
                try:
                    # Compare the PREDICTED label with green, [0, 0, 1].
                    # The original compared true_label with [0, 1, 0], which
                    # can never be equal inside this branch, so the check
                    # always passed.
                    self.assertNotEqual(predicted_label, [0, 0, 1])
                except self.failureException as e:
                    print_fail()
                    print('Warning:A red light is classified as green.')
                    print('\n' + str(e))
                    return

        print_pass()
-
tests = Tests()
-
tests.test_one_hot(one_hot_encode)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
Test Passed
Standardized_Train_List = standardize(IMAGE_LIST)
- 1
Feature Extraction
在這里我們將使用色彩空間、形狀分析及特征構造
RGB to HSV
-
# Visualize one standardized image next to its H, S and V channels
image_num = 0
test_im = Standardized_Train_List[image_num][0]
test_label = Standardized_Train_List[image_num][1]

# Convert to HSV
hsv = cv2.cvtColor(test_im, cv2.COLOR_RGB2HSV)

# Print image label
print('Label [red, yellow, green]: ' + str(test_label))

# Split the HSV image into its three channels
h = hsv[:, :, 0]
s = hsv[:, :, 1]
v = hsv[:, :, 2]

# Plot the original image and the three channels
_, ax = plt.subplots(1, 4, figsize=(20, 10))
ax[0].set_title('Standardized image')
ax[0].imshow(test_im)
ax[1].set_title('H channel')
ax[1].imshow(h, cmap='gray')
ax[2].set_title('S channel')
ax[2].imshow(s, cmap='gray')
ax[3].set_title('V channel')
ax[3].imshow(v, cmap='gray')
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
-
Label [red, yellow, green]: [1, 0, 0]
-
-
-
-
-
-
<matplotlib.image.AxesImage at 0x7fb49ad71f28>
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
-
# create feature
#
# HSV stands for Hue, Saturation, Value; it is also called HSB, where B is
# Brightness.
#   Hue (H):        the basic attribute of a color, i.e. its everyday name
#                   such as red or yellow.
#   Saturation (S): the purity of the color; higher is purer, lower fades
#                   towards gray. Expressed as 0-100%.
#   Value (V):      brightness / lightness (L). Expressed as 0-100%.
def create_feature(rgb_image):
    """Return the average brightness of an RGB image.

    The image is converted to HSV and the V channel (index 2) is averaged
    over all pixels.

    Args:
        rgb_image: RGB image accepted by ``cv2.cvtColor``.

    Returns:
        Sum of the V channel divided by the number of pixels.
    """
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)

    sum_brightness = np.sum(hsv[:, :, 2])
    # Pixel count taken from the image itself; equals 32 * 32 for the
    # standardized images used in this project.
    area = hsv.shape[0] * hsv.shape[1]
    avg_brightness = sum_brightness / area  # Find the average

    return avg_brightness
-
-
def high_saturation_pixels(rgb_image, threshold=80):
    """Return average red and green content of the high-saturation pixels.

    Usually the traffic light contains the highest-saturation pixels in the
    image. The default threshold of 80 was determined experimentally.

    Args:
        rgb_image: RGB image accepted by ``cv2.cvtColor``.
        threshold: Pixels whose saturation is strictly greater than this
            value are treated as high-saturation.

    Returns:
        ``(avg_red, avg_green)`` over the selected pixels, with the green
        average scaled by 0.8. If no pixel exceeds the threshold, the result
        of ``highest_sat_pixel(rgb_image)`` is returned instead.
    """
    # The original referenced an undefined name `rgb` here (NameError).
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)

    # Boolean mask over all pixels; works for any image size instead of a
    # hard-coded 32x32 double loop.
    high_sat_mask = hsv[:, :, 1] > threshold
    if not high_sat_mask.any():
        return highest_sat_pixel(rgb_image)

    # Selected pixels, one row per pixel. mean() accumulates in floating
    # point, so small integer channel values cannot wrap around.
    high_sat_pixels = rgb_image[high_sat_mask]
    avg_red = high_sat_pixels[:, 0].mean()
    # NOTE(review): the source text shows a stray "0.8" after this statement;
    # it is read as a "* 0.8" whose operator was lost in extraction - confirm.
    avg_green = high_sat_pixels[:, 1].mean() * 0.8

    return avg_red, avg_green
-
def highest_sat_pixel(rgb_image):
    """Classify the single most saturated pixel as red-ish or green-ish.

    Finds the pixel with the highest saturation and checks whether it has a
    higher red or a higher green content.

    Args:
        rgb_image: RGB image accepted by ``cv2.cvtColor``.

    Returns:
        ``(1, 0)`` if the pixel's red value exceeds 0.9 times its green
        value, otherwise ``(0, 1)``.
    """
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    s = hsv[:, :, 1]

    # Position of the maximum saturation value in the 2-D channel
    x, y = np.unravel_index(np.argmax(s), s.shape)

    if rgb_image[x, y, 0] > rgb_image[x, y, 1] * 0.9:
        return 1, 0  # red has a higher content
    return 0, 1
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
Test dataset
接下來我們導入測試集來看看,上述方法的測試精度
上述方法我們實現了:
1.求平均的brightness
2.求red及green的色彩飽和度
有人或許會提出疑問,為啥沒有進行yellow的判斷,因此我們作出以下的改善
reference url
這里部分閾值,我們直接參考WIKI上的數據:
-
def estimate_label(rgb_image, display=False):
    """Predict the one-hot label of a standardized RGB image.

    Thin wrapper that delegates to ``red_green_yellow``.

    Args:
        rgb_image: Standardized RGB image.
        display: Forwarded to ``red_green_yellow``.

    Returns:
        Whatever ``red_green_yellow(rgb_image, display)`` returns.
    """
    return red_green_yellow(rgb_image, display)
-
def findNoneZero(rgb_image):
    """Count the pixels whose channel values do not sum to zero.

    Args:
        rgb_image: NumPy array indexed as ``[row, col, channel]``.

    Returns:
        Number of pixels with a non-zero channel sum, as a Python ``int``.
    """
    # Sum over the channel axis inside NumPy. ndarray.sum widens small
    # integer dtypes, so a uint8 pixel such as (128, 128, 0) sums to 256
    # rather than wrapping to 0, which a per-pixel builtin sum() of uint8
    # scalars may do.
    channel_sums = rgb_image.sum(axis=2)
    return int(np.count_nonzero(channel_sums))
-
def red_green_yellow(rgb_image, display):
    """Classify a traffic-light image as red, yellow or green.

    The image is converted to HSV, a mask is built for each color from
    experimentally determined thresholds, and the color with the most
    surviving pixels wins.

    Args:
        rgb_image: RGB image accepted by ``cv2.cvtColor``.
        display: If truthy, plot the input, the three masked results and the
            HSV image.

    Returns:
        One-hot label ``[red, yellow, green]``. Ties are resolved in favor of
        red, then yellow.
    """
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)

    # Average saturation over the whole image. The pixel count is taken from
    # the image itself; it equals 32 * 32 for standardized images.
    sum_saturation = np.sum(hsv[:, :, 1])  # Sum the saturation values
    area = hsv.shape[0] * hsv.shape[1]
    avg_saturation = sum_saturation / area  # find average

    # Lower saturation bound: 1.3x the mean, an empirically chosen factor.
    sat_low = int(avg_saturation * 1.3)
    val_low = 140

    # (lower hue, upper hue) per color; saturation/value bounds are shared.
    hue_ranges = {
        'green': (70, 100),
        'yellow': (10, 60),
        'red': (150, 180),
    }
    results = {}
    for color, (hue_low, hue_high) in hue_ranges.items():
        lower = np.array([hue_low, sat_low, val_low])
        upper = np.array([hue_high, 255, 255])
        mask = cv2.inRange(hsv, lower, upper)
        # Keep only the pixels of the input that fall inside the color range
        results[color] = cv2.bitwise_and(rgb_image, rgb_image, mask=mask)

    green_result = results['green']
    yellow_result = results['yellow']
    red_result = results['red']

    if display:
        _, ax = plt.subplots(1, 5, figsize=(20, 10))
        panels = (
            ('rgb image', rgb_image),
            ('red result', red_result),
            ('yellow result', yellow_result),
            ('green result', green_result),
            ('hsv image', hsv),
        )
        for axis, (title, panel) in zip(ax, panels):
            axis.set_title(title)
            axis.imshow(panel)
        plt.show()

    # Count the pixels that survived each mask
    sum_green = findNoneZero(green_result)
    sum_red = findNoneZero(red_result)
    sum_yellow = findNoneZero(yellow_result)

    if sum_red >= sum_yellow and sum_red >= sum_green:
        return [1, 0, 0]  # Red
    if sum_yellow >= sum_green:
        return [0, 1, 0]  # yellow
    return [0, 0, 1]  # green
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
Test
接下來我們選擇三張圖片來看看測試效果
img_red,img_yellow,img_green
-
# Run the classifier on one sample of each color and compare with the truth
img_test = [(img_red, 'red'), (img_yellow, 'yellow'), (img_green, 'green')]
standardtest = standardize(img_test)

for img in standardtest:
    # img is a (standardized_image, one_hot_label) pair
    predicted_label = estimate_label(img[0], display=True)
    print('Predict label :', predicted_label)
    print('True label:', img[1])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
-
Predict label : [1, 0, 0]
-
True label: [1, 0, 0]
- 1
- 2
- 3
-
Predict label : [0, 1, 0]
-
True label: [0, 1, 0]
- 1
- 2
- 3
-
Predict label : [0, 0, 1]
-
True label: [0, 0, 1]
- 1
- 2
- 3
-
# Using the load_dataset function in helpers.py
-
# Load test data
-
TEST_IMAGE_LIST = load_dataset(IMAGE_DIR_TEST)
-
-
# Standardize the test data
-
STANDARDIZED_TEST_LIST = standardize(TEST_IMAGE_LIST)
-
-
# Shuffle the standardized test data
-
random.shuffle(STANDARDIZED_TEST_LIST)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
-
181
-
9
-
107
- 1
- 2
- 3
- 4
Determine the Accuracy
接下來我們來看看咱們算法在測試集上的准確率。下面我們實現的代碼存儲所有的被錯分的圖片以及它們被預測的結果及真實標簽。
這些數據被存儲在MISCLASSIFIED.
-
# Constructs a list of misclassified images given a list of test images and their labels
# This will throw an AssertionError if labels are not standardized (one-hot encoded)
def get_misclassified_images(test_images, display=False):
    """Collect every test image whose predicted label differs from the truth.

    Args:
        test_images: Iterable of items whose element 0 is an image and whose
            element 1 is its one-hot true label.
        display: Forwarded to ``estimate_label``. The original accepted this
            parameter but always passed ``False``.

    Returns:
        A list of ``(image, predicted_label, true_label)`` tuples, one per
        misclassified image, in input order.

    Raises:
        AssertionError: If a true or predicted label does not have length 3.
    """
    misclassified_images_labels = []

    # Classify each image and compare to the true label
    for image in test_images:
        # Get true data
        im = image[0]
        true_label = image[1]
        assert (len(true_label) == 3), 'This true_label is not the excepted length (3).'

        # Get predicted label from the classifier
        predicted_label = estimate_label(im, display=display)
        assert (len(predicted_label) == 3), 'This predicted_label is not the excepted length (3).'

        # If the labels are not equal, the image has been misclassified
        if predicted_label != true_label:
            misclassified_images_labels.append((im, predicted_label, true_label))

    # Return the list of misclassified (image, predicted_label, true_label) values
    return misclassified_images_labels
-
# Find all misclassified images in a given test set
MISCLASSIFIED = get_misclassified_images(STANDARDIZED_TEST_LIST, display=False)

# Accuracy calculations: correct = total - misclassified
total = len(STANDARDIZED_TEST_LIST)
num_correct = total - len(MISCLASSIFIED)
accuracy = num_correct / total

print('Accuracy:' + str(accuracy))
print('Number of misclassfied images = ' + str(len(MISCLASSIFIED)) + ' out of ' + str(total))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
-
Accuracy:0.9797979797979798
-
Number of misclassfied images = 6 out of 297
- 1
- 2
- 3