1、K-近鄰算法(Knn)
其原理為在一個樣本空間中,有一些已知分類的樣本,當出現一個未知分類的樣本,則根據距離這個未知樣本最近的k個樣本來決定。
舉例:愛情電影和動作電影,它們中都存在吻戲和動作,出現一個未知分類的電影,將根據以吻戲數量和動作數量建立的坐標系中距離未知分類所在點的最近的k個點來決定。
2、算法實現步驟
(1)計算所有點距離未知點的歐式距離
(2)對所有點進行排序
(3)找到距離未知點最近的k個點
(4)計算這k個點所在分類出現的頻率
(5)選擇頻率最大的分類即為未知點的分類
3、java實現
Point類
public class Point { private long id; private double x; private double y; private String type; public Point(long id,double x, double y) { this.x = x; this.y = y; this.id = id; } public Point(long id,double x, double y, String type) { this.x = x; this.y = y; this.type = type; this.id = id; } //get、set方法省略 }
Distance類
public class Distance { // 已知點id private long id; // 未知點id private long nid; // 二者之間的距離 private double disatance; public Distance(long id, long nid, double disatance) { this.id = id; this.nid = nid; this.disatance = disatance; } //get、set方法省略 }
比較器CompareClass類
import java.util.Comparator; //比較器類 public class CompareClass implements Comparator<Distance>{ public int compare(Distance d1, Distance d2) { return d1.getDisatance()>d2.getDisatance()?20 : -1; } }
KNN主類
/** * 1、輸入所有已知點 2、輸入未知點 3、計算所有已知點到未知點的歐式距離 4、根據距離對所有已知點排序 5、選出距離未知點最近的k個點 6、計算k個點所在分類出現的頻率 7、選擇頻率最大的類別即為未知點的類別 * * @author fzj * */ public class KNN { public static void main(String[] args) { // 一、輸入所有已知點 List<Point> dataList = creatDataSet(); // 二、輸入未知點 Point x = new Point(5, 1.2, 1.2); // 三、計算所有已知點到未知點的歐式距離,並根據距離對所有已知點排序 CompareClass compare = new CompareClass(); Set<Distance> distanceSet = new TreeSet<Distance>(compare); for (Point point : dataList) { distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point, x))); } // 四、選取最近的k個點 double k = 5; /** * 五、計算k個點所在分類出現的頻率 */ // 1、計算每個分類所包含的點的個數 List<Distance> distanceList= new ArrayList<Distance>(distanceSet); Map<String, Integer> map = getNumberOfType(distanceList, dataList, k); // 2、計算頻率 Map<String, Double> p = computeP(map, k); x.setType(maxP(p)); System.out.println("未知點的類型為:"+x.getType()); } // 歐式距離計算 public static double oudistance(Point point1, Point point2) { double temp = Math.pow(point1.getX() - point2.getX(), 2) + Math.pow(point1.getY() - point2.getY(), 2); return Math.sqrt(temp); } // 找出最大頻率 public static String maxP(Map<String, Double> map) { String key = null; double value = 0.0; for (Map.Entry<String, Double> entry : map.entrySet()) { if (entry.getValue() > value) { key = entry.getKey(); value = entry.getValue(); } } return key; } // 計算頻率 public static Map<String, Double> computeP(Map<String, Integer> map, double k) { Map<String, Double> p = new HashMap<String, Double>(); for (Map.Entry<String, Integer> entry : map.entrySet()) { p.put(entry.getKey(), entry.getValue() / k); } return p; } // 計算每個分類包含的點的個數 public static Map<String, Integer> getNumberOfType( List<Distance> listDistance, List<Point> listPoint, double k) { Map<String, Integer> map = new HashMap<String, Integer>(); int i = 0; System.out.println("選取的k個點,由近及遠依次為:"); for (Distance distance : listDistance) { System.out.println("id為" + distance.getId() + ",距離為:" + distance.getDisatance()); long id = distance.getId(); // 通過id找到所屬類型,並存儲到HashMap中 for (Point point : listPoint) { if (point.getId() == id) { if (map.get(point.getType()) != null) map.put(point.getType(), map.get(point.getType()) + 1); else { map.put(point.getType(), 1); } } } i++; if (i >= k) break; } return map; } public static ArrayList<Point> creatDataSet(){ Point point1 = new Point(1, 1.0, 1.1, "A"); Point point2 = new Point(2, 1.0, 1.0, "A"); Point point3 = new Point(3, 1.0, 1.2, "A"); Point point4 = new Point(4, 0, 0, "B"); Point point5 = new Point(5, 0, 0.1, "B"); Point point6 = new Point(6, 0, 0.2, "B"); ArrayList<Point> dataList = new ArrayList<Point>(); dataList.add(point1); dataList.add(point2); dataList.add(point3); dataList.add(point4); dataList.add(point5); dataList.add(point6); return dataList; } }
4、運行結果
參考
[1] 《機器學習實戰》