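Recently I had a requirement to group customers on a map by distance. For example, if customer a is 5 km from customer b and b is 5 km from customer c, then a, b and c can be aggregated into one group. The natural first step was to look for an algorithm that clusters points by their coordinates. After comparing a few, I chose DBSCAN, which is relatively simple and fits the requirement.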
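DBSCAN is a density-based clustering algorithm. Put simply, it splits the samples into clusters according to how closely packed they are and how many of them lie close together. The samples are usually a set of coordinate points. It takes two parameters:

- a distance threshold ε, which is the neighborhood radius (measured here as Euclidean distance);
- a density threshold minPts, which is the minimum number of neighbors a point must have within ε to start or extend a cluster.

The result is a list of clusters. Points that belong to no cluster are treated as noise.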
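The sketch below makes the two parameters concrete. It shows the ε-neighborhood query and the core-point test on plain 2-D Euclidean distance. The class CorePointDemo and its method names are hypothetical.

It follows the implementation later in this article and does not count a point's own position among its neighbors. Some DBSCAN descriptions do count the point itself, which shifts minPts by one.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class: epsilon-neighborhood query and core-point test.
public class CorePointDemo {

    // Euclidean distance between two 2-D points
    static double dist(double[] a, double[] b) {
        double dx = a[0] - b[0];
        double dy = a[1] - b[1];
        return Math.sqrt(dx * dx + dy * dy);
    }

    // Indices of all points within eps of points[i], excluding i itself
    static List<Integer> regionQuery(double[][] points, int i, double eps) {
        List<Integer> result = new ArrayList<>();
        for (int j = 0; j < points.length; j++) {
            if (j != i && dist(points[i], points[j]) <= eps) {
                result.add(j);
            }
        }
        return result;
    }

    // A point is a core point if it has at least minPts neighbors within eps
    static boolean isCore(double[][] points, int i, double eps, int minPts) {
        return regionQuery(points, i, eps).size() >= minPts;
    }

    public static void main(String[] args) {
        double[][] pts = { {1, 3}, {0, 2}, {1, 1}, {8, 16} };
        for (int i = 0; i < pts.length; i++) {
            System.out.println(i + " neighbors=" + regionQuery(pts, i, 1.5)
                    + " core(minPts=2)=" + isCore(pts, i, 1.5, 2));
        }
    }
}
```

With ε = 1.5 and minPts = 2, only (0,2) is a core point among these four. (1,3) and (1,1) are border points: each lies within ε of that core point, so DBSCAN still puts all three in one cluster. (8,16) has no neighbor and is noise.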
2. Java implementation
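Point class: for the test, only its point field (the coordinate array) is used. The class implements the Clusterable interface from Apache Commons Math 2.x (package org.apache.commons.math).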
```java
import java.util.Arrays;
import java.util.Collection;

import org.apache.commons.math.stat.clustering.Clusterable;
import org.apache.commons.math.util.MathUtils;

/**
 * @author xjx
 */
public class CustomerPoint implements Clusterable<CustomerPoint> {

    private String sender;
    private String sender_addr;
    private int value;
    // Coordinates; toString prints point[0] as lat and point[1] as lng
    private final double[] point;

    public CustomerPoint(final double[] point) {
        this.point = point;
    }

    public int getValue() {
        return value;
    }

    public void setValue(int value) {
        this.value = value;
    }

    public String getSender() {
        return sender;
    }

    public void setSender(String sender) {
        this.sender = sender;
    }

    public String getSender_addr() {
        return sender_addr;
    }

    public void setSender_addr(String sender_addr) {
        this.sender_addr = sender_addr;
    }

    public double[] getPoint() {
        return point;
    }

    // Euclidean distance between the two coordinate arrays
    @Override
    public double distanceFrom(final CustomerPoint p) {
        return MathUtils.distance(point, p.getPoint());
    }

    // Component-wise average of the given points (required by Clusterable, not used by DBSCAN)
    @Override
    public CustomerPoint centroidOf(final Collection<CustomerPoint> points) {
        double[] centroid = new double[getPoint().length];
        for (CustomerPoint p : points) {
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += p.getPoint()[i];
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= points.size();
        }
        return new CustomerPoint(centroid);
    }

    // Two points are equal if their coordinates are equal
    @Override
    public boolean equals(final Object other) {
        if (!(other instanceof CustomerPoint)) {
            return false;
        }
        return Arrays.equals(point, ((CustomerPoint) other).getPoint());
    }

    // Overridden together with equals so hash-based collections behave consistently
    @Override
    public int hashCode() {
        return Arrays.hashCode(point);
    }

    @Override
    public String toString() {
        final StringBuilder buff = new StringBuilder("{");
        buff.append("lat:").append(point[0]).append(',');
        buff.append("lng:").append(point[1]).append(',');
        buff.append("value:").append(value);
        buff.append('}');
        return buff.toString();
    }
}
```
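distanceFrom delegates to MathUtils.distance, which computes the Euclidean distance between the raw coordinate values. If point holds latitude/longitude in degrees, as the lat/lng labels in toString suggest, the threshold is then in degrees rather than kilometres. A rule like "within 5 km" would need a great-circle distance instead.

The sketch below uses the haversine formula. The class GeoDistance, the method haversineKm and the mean Earth radius of 6371 km (a spherical approximation) are assumptions, not part of the original code.

```java
// Hypothetical helper: great-circle distance in kilometres via the haversine formula.
// Assumes each coordinate array is {latitude, longitude} in decimal degrees.
public final class GeoDistance {

    // Mean Earth radius in km (spherical approximation)
    static final double EARTH_RADIUS_KM = 6371.0;

    private GeoDistance() {
    }

    static double haversineKm(double[] a, double[] b) {
        double lat1 = Math.toRadians(a[0]);
        double lat2 = Math.toRadians(b[0]);
        double dLat = lat2 - lat1;
        double dLng = Math.toRadians(b[1] - a[1]);
        double h = Math.sin(dLat / 2) * Math.sin(dLat / 2)
                + Math.cos(lat1) * Math.cos(lat2) * Math.sin(dLng / 2) * Math.sin(dLng / 2);
        // Clamp so rounding can never push the asin argument above 1
        return 2 * EARTH_RADIUS_KM * Math.asin(Math.sqrt(Math.min(1.0, h)));
    }

    public static void main(String[] args) {
        // One degree of latitude is roughly 111.2 km on this sphere
        System.out.println(haversineKm(new double[] {0, 0}, new double[] {1, 0}));
    }
}
```

Under these assumptions, distanceFrom could return GeoDistance.haversineKm(point, p.getPoint()). The first constructor argument of DBScanTest3 would then be in kilometres, for example 5.0 for the 5 km rule.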
3. Algorithm implementation and test:
```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import ...CustomerPoint;

/**
 *
 * @author xjx
 *
 */
public class DBScanTest3 {

    // Neighborhood radius (Euclidean distance)
    private final double distance;
    // Minimum number of neighbors (not counting the point itself) for a point to be a core point
    private final int minPoints;
    // Label of every visited point, keyed by object identity so that two customers
    // with identical coordinates are still treated as separate points
    private final Map<CustomerPoint, PointStatus> visited = new IdentityHashMap<CustomerPoint, PointStatus>();

    // Point labels: POINT = belongs to a cluster, NOISE = noise point
    private enum PointStatus {
        NOISE, POINT
    }

    public DBScanTest3(final double distance, final int minPoints) {
        if (distance < 0.0d) {
            throw new IllegalArgumentException("distance must not be negative");
        }
        if (minPoints < 0) {
            throw new IllegalArgumentException("minPoints must not be negative");
        }
        this.distance = distance;
        this.minPoints = minPoints;
    }

    public double getDistance() {
        return distance;
    }

    public int getMinPoints() {
        return minPoints;
    }

    public Map<CustomerPoint, PointStatus> getVisited() {
        return visited;
    }

    /**
     * Clusters the customer points.
     * @param points the points to cluster
     * @return one list per cluster; noise points are not included
     */
    public List<List<CustomerPoint>> cluster(List<CustomerPoint> points) {
        // Reset labels so the same instance can be reused
        visited.clear();
        final List<List<CustomerPoint>> clusters = new ArrayList<List<CustomerPoint>>();
        for (CustomerPoint point : points) {
            // Skip points that have already been labeled
            if (visited.get(point) != null) {
                continue;
            }
            List<CustomerPoint> neighbors = getNeighbors(point, points);
            if (neighbors.size() >= minPoints) {
                // Core point: start a new cluster and grow it through the neighbors
                clusters.add(expandCluster(new ArrayList<CustomerPoint>(), point, neighbors, points));
            } else {
                // Noise for now; it may later join a cluster as a border point
                visited.put(point, PointStatus.NOISE);
            }
        }
        return clusters;
    }

    private List<CustomerPoint> expandCluster(List<CustomerPoint> cluster, CustomerPoint point,
            List<CustomerPoint> neighbors, List<CustomerPoint> points) {
        cluster.add(point);
        visited.put(point, PointStatus.POINT);

        // Queue of points still to examine, plus an identity set so no point is queued twice
        List<CustomerPoint> seeds = new ArrayList<CustomerPoint>(neighbors);
        Set<CustomerPoint> queued = Collections.newSetFromMap(new IdentityHashMap<CustomerPoint, Boolean>());
        queued.addAll(seeds);

        int index = 0;
        // Walk through all queued neighbors
        while (index < seeds.size()) {
            CustomerPoint current = seeds.get(index);
            PointStatus pStatus = visited.get(current);
            // Only an unvisited point that is itself a core point extends the cluster further
            if (pStatus == null) {
                List<CustomerPoint> currentNeighbors = getNeighbors(current, points);
                if (currentNeighbors.size() >= minPoints) {
                    for (CustomerPoint n : currentNeighbors) {
                        if (queued.add(n)) {
                            seeds.add(n);
                        }
                    }
                }
            }
            // If the point is not yet in a cluster (unvisited or noise), label it and add it
            if (pStatus != PointStatus.POINT) {
                visited.put(current, PointStatus.POINT);
                cluster.add(current);
            }
            index++;
        }
        return cluster;
    }

    // Finds all points within `distance` of `point`, whatever their current label
    private List<CustomerPoint> getNeighbors(CustomerPoint point, List<CustomerPoint> points) {
        List<CustomerPoint> neighbors = new ArrayList<CustomerPoint>();
        for (CustomerPoint neighbor : points) {
            if (point != neighbor && neighbor.distanceFrom(point) <= distance) {
                neighbors.add(neighbor);
            }
        }
        return neighbors;
    }

    // Test data
    public static void main(String[] args) {
        // 14 sample points; the list order determines the order of the printed output
        double[][] coords = {
            {1, 3}, {1, 1}, {8, 16}, {0, 2}, {7, 4}, {7, 3}, {6, 3},
            {5, 2}, {5, 1}, {3, 9}, {5, 6}, {4, 8}, {4, 7}, {3, 8}
        };
        List<CustomerPoint> cs = new ArrayList<CustomerPoint>();
        for (double[] c : coords) {
            cs.add(new CustomerPoint(c));
        }

        // First argument: distance threshold; second: minimum number of neighbors
        DBScanTest3 db = new DBScanTest3(1.5, 1);

        // Cluster and print one cluster per line
        List<List<CustomerPoint>> clusters = db.cluster(cs);
        for (List<CustomerPoint> cluster : clusters) {
            for (CustomerPoint p : cluster) {
                System.out.print(p);
            }
            System.out.println();
        }
    }
}
```
Output:
```
{lat:1.0,lng:3.0,value:0}{lat:0.0,lng:2.0,value:0}{lat:1.0,lng:1.0,value:0}
{lat:7.0,lng:4.0,value:0}{lat:7.0,lng:3.0,value:0}{lat:6.0,lng:3.0,value:0}{lat:5.0,lng:2.0,value:0}{lat:5.0,lng:1.0,value:0}
{lat:3.0,lng:9.0,value:0}{lat:4.0,lng:8.0,value:0}{lat:3.0,lng:8.0,value:0}{lat:4.0,lng:7.0,value:0}{lat:5.0,lng:6.0,value:0}
```
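The result contains three clusters. The one remaining point, (8,16), has no neighbor within 1.5 and is noise. If you plot the points on a grid, you can see they fall into three groups.

Points in the same cluster are not all pairwise closer than 1.5. Instead, each point is within 1.5 of at least one other point in its cluster, and clusters grow by chaining. For example, (1,1) and (1,3) are 2 apart but are linked through (0,2), just like customers a, b and c in the opening example.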
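With minPoints = 1, as in this test, every point that has at least one neighbor within ε is a core point. DBSCAN's clusters then coincide with the connected components of the graph that links every pair of points at most ε apart, and isolated points are noise.

The sketch below is not the DBSCAN code above. It is a union-find cross-check of that equivalence, assuming neighbor counts exclude the point itself (as getNeighbors does). The class ComponentsDemo is hypothetical, and it works on indices into the same 14 points, in the same order as the test list.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical cross-check: for minPoints = 1, DBSCAN clusters equal the connected
// components of the graph linking every pair of points at most eps apart.
public class ComponentsDemo {

    // Union-find "find" with path halving
    static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    // Returns groups of point indices; isolated points (noise) are left out
    static List<List<Integer>> components(double[][] pts, double eps) {
        int[] parent = new int[pts.length];
        for (int i = 0; i < pts.length; i++) {
            parent[i] = i;
        }
        // Union every pair of points within eps (Euclidean distance)
        for (int i = 0; i < pts.length; i++) {
            for (int j = i + 1; j < pts.length; j++) {
                double dx = pts[i][0] - pts[j][0];
                double dy = pts[i][1] - pts[j][1];
                if (Math.sqrt(dx * dx + dy * dy) <= eps) {
                    parent[find(parent, i)] = find(parent, j);
                }
            }
        }
        // Group indices by root, in order of first appearance
        Map<Integer, List<Integer>> groups = new LinkedHashMap<>();
        for (int i = 0; i < pts.length; i++) {
            groups.computeIfAbsent(find(parent, i), k -> new ArrayList<>()).add(i);
        }
        List<List<Integer>> result = new ArrayList<>();
        for (List<Integer> g : groups.values()) {
            if (g.size() > 1) {
                result.add(g);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // The 14 test points, in the same order as the test list
        double[][] pts = {
            {1, 3}, {1, 1}, {8, 16}, {0, 2}, {7, 4}, {7, 3}, {6, 3},
            {5, 2}, {5, 1}, {3, 9}, {5, 6}, {4, 8}, {4, 7}, {3, 8}
        };
        System.out.println(components(pts, 1.5));
    }
}
```

This prints [[0, 1, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12, 13]]. Indices 0, 1 and 3 are (1,3), (1,1) and (0,2), and index 2, (8,16), is left out as noise, matching the three clusters printed above. For minPoints greater than 1 the equivalence no longer holds, because border points do not extend a cluster.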