基於OpenCV的KNN算法實現手寫數字識別


基於OpenCV的KNN算法實現手寫數字識別

一、數據預處理

# 導入所需模塊
import cv2
import numpy as np
import matplotlib.pyplot as plt
# 顯示灰度圖
def plt_show(img):
    plt.imshow(img,cmap='gray')
    plt.show()
# 加載數據集圖片數據
digits = cv2.imread('./image/digits.png',0)
print(digits.shape)
plt_show(digits)
(1000, 2000)

# 划分數據
cells = [np.hsplit(row,100) for row in np.vsplit(digits,50)] 

len(cells)
50
# 轉換為numpy數組
x = np.array(cells)
x.shape
(50, 100, 20, 20)
plt_show(x[5][0])

# 生成訓練數據標簽和測試數據標簽
k = np.arange(10)
train_label = np.repeat(k,250)
test_label = train_label.copy()
# 圖片數據轉換為特征矩陣,划分訓練數據集
train = x[:,:50].reshape(-1,400).astype(np.float32)
# 圖片數據轉換為特征矩陣,划分測試數據集
test = x[:,50:100].reshape(-1,400).astype(np.float32)
test.shape
(2500, 400)

二、knn算法預測

# 生成模型
knn = cv2.ml.KNearest_create()
# 訓練數據
knn.train(train,cv2.ml.ROW_SAMPLE,train_label)
True
# 傳入n值,和測試數據,返回結果
ret,result,neighbours,dist = knn.findNearest(test, 3)
# 統計正確的個數
res = 0
for i in range(2500):
    if result[i]==test_label[i]:
        res = res+1
res
2439
# 計算模型准確率
accuracy = res/result.size
print('識別測試數據的准確率為:',accuracy)
識別測試數據的准確率為: 0.9756

三、導入圖片預測

# 在測試集中隨便找一張圖片
test_image = test[2400].reshape(20,20)
plt_show(test_image)
test_label[2400]

# 將圖片轉換為特征矩陣
testImage = test[2400].reshape(-1,400).astype(np.float32)
testImage.shape
(1, 400)
# 使用訓練好的模型預測
ret,result,neighbours,dist = knn.findNearest(testImage, 3)
# 預測結果
print('識別出的數字為:',result[0][0])
識別出的數字為: 9.0
# 傳入一張自己找的圖片進行識別尺寸(20*20)
te = cv2.imread('test2.jpg',0)
plt_show(te)
te.shape

(20, 20)

testImage = te.reshape(-1,400).astype(np.float32)
testImage.shape
(1, 400)
ret,result,neighbours,dist = knn.findNearest(testImage, 3)
result
array([[2.]], dtype=float32)
print('識別出的數字為:',result[0][0])
識別出的數字為: 2.0

用自己寫的一張圖片預測

# 用所有數據作為訓練數據
knn = cv2.ml.KNearest_create()
k = np.arange(10)
labels = np.repeat(k,500)
knn.train(x.reshape(-1,400).astype(np.float32),cv2.ml.ROW_SAMPLE,labels)
True
te = cv2.imread('test1.jpg',0)
plt_show(te)
te.shape

(20, 20)

# 自適應閾值處理
ret, image = cv2.threshold(te, 0, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY_INV)
plt_show(image)

# 將圖片轉換為特征矩陣
testImage = image.reshape(-1,400).astype(np.float32)
testImage.shape
(1, 400)
# 使用訓練好的模型預測
ret,result,neighbours,dist = knn.findNearest(testImage, 3)
neighbours
array([[5., 5., 5.]], dtype=float32)
print('識別出的數字為:',result[0][0])
識別出的數字為: 5.0

資源地址:

鏈接:https://pan.baidu.com/s/1sUgKBvex43-Yf-Ul2DQSIA
提取碼:t1sd

視頻地址:https://www.bilibili.com/video/BV14A411t7tk/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM