最鄰近規則分類 K-Nearest Neighbor
步驟:
1、為了判斷未知實例的類別,以所有已知類別的實例作為參考。
2、選擇參數K。
3、計算未知實例與所有已知實例的距離。
4、選擇最近的K個已知實例。
5、根據少數服從多數,讓未知實例歸類為K個最鄰近樣本中最多數的類別。
優點:簡單,易於理解,容易實現,通過對K的選擇可具備丟噪音數據的強壯性。
缺點:1、需要大量空間存儲所有已知實例。2、當樣本分布不均衡時,比如其中一類樣本實例數量過多,占主導的時候,新的未知實例很容易被歸類這個主導樣本。
改進:考慮距離,根據距離加上權重。
KNN算法實現:
import csv
import math
import random
# 從數據文件中獲取訓練集和測試集
def loadDataset(filename, split, trainingSet=[], testSet=[]):
with open(filename, 'r') as file:
lines = csv.reader(file)
data = list(lines)
for x in range(len(data) - 1):
for y in range(4):
data[x][y] = float(data[x][y])
if random.random() < split:
trainingSet.append(data[x])
else:
testSet.append(data[x])
# 計算兩個實例間的距離
def getDistance(instance1, instance2, length):
distance = 0
for x in range(length):
distance += math.pow(instance1[x] - instance2[x], 2)
return math.sqrt(distance)
# 獲取測試集單個實例附近k范圍的所有實例
def getNeighbors(trainingSet, testInstance, k):
distances = []
length = len(testInstance) - 1
for x in range(len(trainingSet)):
d = getDistance(trainingSet[x], testInstance, length)
distances.append((trainingSet[x], d))
newDistances = sorted(distances, key=lambda x: x[1])
neighbors = []
for x in range(k):
neighbors.append(newDistances[x][0])
return neighbors
# 通過獲得的K范圍內所有實例,獲得實例中最多的類別屬於哪一類
def getResponse(neighbors):
classDict = {} # 定義一個字典用於統計每個類別的個數
for x in range(len(neighbors)):
response = neighbors[x][-1]
if response in classDict:
classDict[response] += 1
else:
classDict[response] = 1
newClassDict = sorted(classDict.items(), key=lambda x: x[1], reverse=True)
return newClassDict
# 計算預測正確的概率
def getAccuracy(testSet, predictions):
correct = 0
for x in range(len(testSet)):
if testSet[x][-1] == predictions[x][0][0]:
correct += 1
return (correct / float(len(testSet)))
def main():
trainingSet = []
testSet = []
split = 0.8
predictions = []
loadDataset('D:\daacheng\Python\PythonCode\machineLearning\irisdata.txt', split, trainingSet, testSet)
print('Train set: ' + repr(len(trainingSet)))
print('Test set: ' + repr(len(testSet)))
k = 5
for x in range(len(testSet)):
neighbors = getNeighbors(trainingSet, testSet[x], k)
result = getResponse(neighbors)
predictions.append(result)
print(testSet)
print('-----------------------------------------------')
print(predictions)
correct = getAccuracy(testSet, predictions)
print(correct)
if __name__ == '__main__':
main()