本文主要是用kNN算法對字母圖片進行特征提取,分類識別。內容如下:
- kNN算法及相關Python模塊介紹
- 對字母圖片進行特征提取
- kNN算法實現
- kNN算法分析
一、kNN算法介紹
K近鄰(kNN,k-NearestNeighbor)分類算法是機器學習算法中最簡單的方法之一。所謂K近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最接近的k個鄰居來代表。我們將樣本分為訓練樣本和測試樣本。對一個測試樣本 t 進行分類,kNN的做法是先計算樣本 t 到所有訓練樣本的歐氏距離,然后從中找出k個距離最短的訓練樣本,用這k個訓練樣本中出現次數最多的類別表示樣本 t 的類別。
歐式距離的計算公式:
假設每個樣本有兩個特征值,如 A :(a1,b1)B:(a2,b2) 則AB的歐式距離為
舉個例子:根據下圖前四位同學的成績和等級,預測第五位小白同學的等級。
我們可以看出:語文和數學成績是一個學生的特征,等級是一個學生的類別。
前四位同學是訓練樣本,第五位同學是測試樣本。我們現在用kNN算法來預測第五位同學的等級,k取3。
按照上面歐式距離公式我們可以計算
d(5-1)== 7 d(5-2)=
= 30
d(5-3)== 6 d(5-4)=
= 19.2
因為 k 取 3,所以我們尋找3個距離最近的樣本,即編號為3,1,4的同學,他們的等級分別是 B,B,A。 這三個樣本的分類中,出現了2次B,一次A,B出現次數最多,所以5號同學的等級可能為B
常用Python模塊
NumPy:NumPy是Python的一種開源的數值計算擴展。這種工具可用來存儲和處理大型矩陣,比Python自身的嵌套列表結構要高效的多。
PIL:Python Imaging Library,是Python平台事實上的圖像處理標准庫,功能非常強大,API也簡單易用。但PIL包主要針對Python2,不兼容Python3,所以在Python3中使用Pillow,后者是大牛根據PIL移植過來的,兩者用法相同。
上面兩個Python庫都可以通過pip進行安裝。
pip3 install [name]
還有就是Python 自帶標准庫:shutil模塊提供了大量的文件的高級操作,特別針對文件拷貝和刪除,主要功能為目錄和文件操作以及壓縮操作。operator模塊是Python 的運算符庫,os 模塊是Python的系統的和操作系統相關的函數庫。
二、對圖片進行特征提取
1、采集手寫字母的圖片素材
有許多提供機器學習數據集的網站,如知乎上的整理 https://www.zhihu.com/question/63383992/answer/222718972 我搜集到的手寫字母圖片資源如下 鏈接:https://pan.baidu.com/s/1pM329fl 密碼:i725 其中by_class.zip 壓縮包是已經分類好的圖片樣本,可以直接下載使用
2、提取圖片素材的特征
最簡單的做法是將圖片轉換為由0 和1 組成的txt 文件,如
轉換代碼如下:
1 import os 2 import shutil 3 from PIL import Image 4 5 6 # image_file_prefix png圖片所在的文件夾 7 # file_name png png圖片的名字 8 # txt_path_prefix 轉換后txt 文件所在的文件夾 9 def generate_txt_image(image_file_prefix, file_name, txt_path_prefix): 10 """將圖片處理成只有0 和 1 的txt 文件""" 11 # 將png圖片轉換成二值圖並截取四周多余空白部分 12 image_path = os.path.join(image_file_prefix, file_name) 13 # convert('L') 將圖片轉為灰度圖 convert('1') 將圖片轉為二值圖 14 img = Image.open(image_path, 'r').convert('1').crop((32, 32, 96, 96)) 15 # 指定轉換后的寬 高 16 width, height = 32, 32 17 img.thumbnail((width, height), Image.ANTIALIAS) 18 # 將二值圖片轉換為0 1,存儲到二位數組arr中 19 arr = [] 20 for i in range(width): 21 pixels = [] 22 for j in range(height): 23 pixel = int(img.getpixel((j, i))) 24 pixel = 0 if pixel == 0 else 1 25 pixels.append(pixel) 26 arr.append(pixels) 27 28 # 創建txt文件(mac下使用os.mknod()創建文件需要root權限,這里改用復制的方式) 29 text_image_file = os.path.join(txt_path_prefix, file_name.split('.')[0] + '.txt') 30 empty_txt_path = "/Users/beiyan/Downloads/empty.txt" 31 shutil.copyfile(empty_txt_path, text_image_file) 32 33 # 寫入文件 34 with open(text_image_file, 'w') as text_file_object: 35 for line in arr: 36 for e in line: 37 text_file_object.write(str(e)) 38 text_file_object.write("\n")
將所有素材轉換為 txt 后,分為兩部分:訓練樣本 和 測試樣本。
三、kNN算法實現
1、將txt文件轉為一維數組的方法:
1 def img2vector(filename, width, height): 2 """將txt文件轉為一維數組""" 3 return_vector = np.zeros((1, width * height)) 4 fr = open(filename) 5 for i in range(height): 6 line = fr.readline() 7 for j in range(width): 8 return_vector[0, height * i + j] = int(line[j]) 9 return return_vector
2、對測試樣本進行kNN分類,返回測試樣本的類別:
1 import numpy as np 2 import os 3 import operator 4 5 6 # test_set 單個測試樣本 7 # train_set 訓練樣本二維數組 8 # labels 訓練樣本對應的分類 9 # k k值 10 def classify(test_set, train_set, labels, k): 11 """對測試樣本進行kNN分類,返回測試樣本的類別""" 12 # 獲取訓練樣本條數 13 train_size = train_set.shape[0] 14 15 # 計算特征值的差值並求平方 16 # tile(A,(m,n)),功能是將數組A行重復m次 列重復n次 17 diff_mat = np.tile(test_set, (train_size, 1)) - train_set 18 sq_diff_mat = diff_mat ** 2 19 20 # 計算歐式距離 存儲到數組 distances 21 sq_distances = sq_diff_mat.sum(axis=1) 22 distances = sq_distances ** 0.5 23 24 # 按距離由小到大排序對索引進行排序 25 sorted_index = distances.argsort() 26 27 # 求距離最短k個樣本中 出現最多的分類 28 class_count = {} 29 for i in range(k): 30 near_label = labels[sorted_index[i]] 31 class_count[near_label] = class_count.get(near_label, 0) + 1 32 sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True) 33 return sorted_class_count[0][0]
3、統計分類錯誤率
1 # train_data_path 訓練樣本文件夾 2 # test_data_path 測試樣本文件夾 3 # k k個最近鄰居 4 def get_error_rate(train_data_path, test_data_path, k): 5 """統計識別錯誤率""" 6 width, height = 32, 32 7 train_labels = [] 8 9 training_file_list = os.listdir(train_data_path) 10 train_size = len(training_file_list) 11 12 # 生成全為0的訓練集數組 13 train_set = np.zeros((train_size, width * height)) 14 15 # 讀取訓練樣本 16 for i in range(train_size): 17 file = training_file_list[i] 18 file_name = file.split('.')[0] 19 label = str(file_name.split('_')[0]) 20 train_labels.append(label) 21 train_set[i, :] = img2vector(os.path.join(train_data_path, training_file_list[i]), width, height) 22 23 test_file_list = os.listdir(test_data_path) 24 # 識別錯誤的個數 25 error_count = 0.0 26 # 測試樣本的個數 27 test_count = len(test_file_list) 28 29 # 統計識別錯誤的個數 30 for i in range(test_count): 31 file = test_file_list[i] 32 true_label = file.split('.')[0].split('_')[0] 33 34 test_set = img2vector(os.path.join(test_data_path, test_file_list[i]), width, height) 35 test_label = classify(test_set, train_set, train_labels, k) 36 print(true_label, test_label) 37 if test_label != true_label: 38 error_count += 1.0 39 percent = error_count / float(test_count) 40 print("識別錯誤率是:{}".format(str(percent)))
上述完整代碼地址:https://gitee.com/beiyan/machine_learning/tree/master/knn
4、測試結果
訓練樣本: 0-9,a-z,A-Z 共62個字符,每個字符選取120個訓練樣本 , 一共有7440 個訓練樣本。每個字符選取20個測試樣本,一共1200個測試樣本。
嘗試改變條件,測得識別正確率如下:
四、kNN算法分析
由上部分結果可知:knn算法對於手寫字母的識別率並不理想。
原因可能有以下幾個方面:
1、圖片特征提取過於簡單,圖片邊緣較多空白,且圖片中字母的中心位置未必全部對應
2、因為英文有些字母大小寫比較相似,容易識別錯誤
3、樣本規模較小,每個字符最多只有300個訓練樣本,真正的訓練需要海量數據
在后序的文章中嘗試用其他學習算法提高分類識別率。各位道友有更好的意見也歡迎提出!