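I recently had a requirement to group customers on a map by distance. For example, if customer a is 5 km from customer b, and b is 5 km from customer c, then a, b and c can be merged into one group. My first thought was to find an algorithm that clusters by coordinates. After looking at a few, I settled on DBSCAN, which is fairly simple and meets the requirement.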
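DBSCAN is a density-based clustering algorithm. Put simply, it splits samples into several groups according to how closely packed and how numerous they are. The samples are usually a set of coordinate points. It takes two parameters: a distance threshold (measured here as Euclidean distance) and a neighborhood density threshold (the minimum number of neighbors a point must have each time neighbors are searched). It returns several groups of samples.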
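To make the two parameters concrete, here is a minimal sketch of the neighborhood test on plain double[] points, using only the standard library. The class NeighborhoodSketch and its methods are hypothetical names chosen for illustration. The sketch assumes that a point does not count as its own neighbor, which is one of several possible conventions.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only.
final class NeighborhoodSketch {

    // Euclidean distance between two points of equal dimension.
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // All points other than `center` itself that lie within `eps` of it.
    static List<double[]> neighborsOf(double[] center, List<double[]> points, double eps) {
        List<double[]> result = new ArrayList<>();
        for (double[] p : points) {
            if (p != center && euclidean(center, p) <= eps) {
                result.add(p);
            }
        }
        return result;
    }

    // A point is dense enough to start a cluster when it has at least minPoints neighbors.
    static boolean canSeedCluster(double[] center, List<double[]> points, double eps, int minPoints) {
        return neighborsOf(center, points, eps).size() >= minPoints;
    }

    public static void main(String[] args) {
        List<double[]> points = new ArrayList<>();
        points.add(new double[] {1, 3});
        points.add(new double[] {0, 2});
        points.add(new double[] {1, 1});
        points.add(new double[] {8, 16});
        // (1,3) has (0,2) within 1.5, so it can seed a cluster.
        System.out.println(canSeedCluster(points.get(0), points, 1.5, 1)); // true
        // (8,16) has no point within 1.5, so it cannot.
        System.out.println(canSeedCluster(points.get(3), points, 1.5, 1)); // false
    }
}
```

With a distance threshold of 1.5 and a minimum of 1 neighbor, any point that has at least one other point nearby can start a group, and an isolated point cannot.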
Java implementation
The coordinate point: for testing, only the point attribute of this class (the coordinates) is used
import java.util.Collection;
import org.apache.commons.math.stat.clustering.Clusterable;
import org.apache.commons.math.util.MathUtils;
/**
* @author xjx
*
*/
public class CustomerPoint implements Clusterable<CustomerPoint>{
private String sender;
private String sender_addr;
private int value;
private final double[] point;
public int getValue() {
return value;
}
public void setValue(int value) {
this.value = value;
}
public String getSender() {
return sender;
}
public void setSender(String sender) {
this.sender = sender;
}
public String getSender_addr() {
return sender_addr;
}
public void setSender_addr(String sender_addr) {
this.sender_addr = sender_addr;
}
public CustomerPoint(final double[] point) {
this.point = point;
}
public double[] getPoint() {
return point;
}
public double distanceFrom(final CustomerPoint p) {
return MathUtils.distance(point, p.getPoint());
}
public CustomerPoint centroidOf(final Collection<CustomerPoint> points) {
double[] centroid = new double[getPoint().length];
for (CustomerPoint p : points) {
for (int i = 0; i < centroid.length; i++) {
centroid[i] += p.getPoint()[i];
}
}
for (int i = 0; i < centroid.length; i++) {
centroid[i] /= points.size();
}
return new CustomerPoint(centroid);
}
@Override
public boolean equals(final Object other) {
if (!(other instanceof CustomerPoint)) {
return false;
}
final double[] otherPoint = ((CustomerPoint) other).getPoint();
if (point.length != otherPoint.length) {
return false;
}
for (int i = 0; i < point.length; i++) {
if (point[i] != otherPoint[i]) {
return false;
}
}
return true;
}
@Override
public String toString() {
final StringBuffer buff = new StringBuffer("{");
final double[] coordinates = getPoint();
buff.append("lat:"+coordinates[0]+",");
buff.append("lng:"+coordinates[1]+",");
buff.append("value:"+this.getValue());
buff.append("}");
return buff.toString();
}
}
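distanceFrom above computes a plain Euclidean distance, which suits the integer test coordinates. If point held a latitude and longitude in degrees and the threshold were meant in kilometers, as in the 5 km example, a great-circle distance could be swapped in instead. Here is a minimal sketch of the haversine formula, assuming a spherical Earth with a mean radius of 6371 km. The class HaversineSketch and the method distanceKm are hypothetical names and are not part of the original code.

```java
// Hypothetical helper for illustration only.
final class HaversineSketch {

    // Mean Earth radius in kilometers (spherical approximation, an assumption).
    static final double EARTH_RADIUS_KM = 6371.0;

    // Great-circle distance in kilometers between two lat/lng pairs given in degrees.
    static double distanceKm(double lat1, double lng1, double lat2, double lng2) {
        double dLat = Math.toRadians(lat2 - lat1);
        double dLng = Math.toRadians(lng2 - lng1);
        double a = Math.sin(dLat / 2) * Math.sin(dLat / 2)
                + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2))
                * Math.sin(dLng / 2) * Math.sin(dLng / 2);
        // Clamp to 1.0 so rounding error cannot push asin outside its domain.
        return 2 * EARTH_RADIUS_KM * Math.asin(Math.min(1.0, Math.sqrt(a)));
    }

    public static void main(String[] args) {
        // One degree of latitude along a meridian is roughly 111 km.
        System.out.println(distanceKm(0.0, 0.0, 1.0, 0.0));
    }
}
```

With such a method, distanceFrom could return distanceKm(point[0], point[1], p.getPoint()[0], p.getPoint()[1]), provided the array really stores latitude first and longitude second, as toString suggests.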
Algorithm implementation and test:
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ...CustomerPoint;
/**
*
* @author xjx
*
*/
public class DBScanTest3{
// Distance threshold (Euclidean distance)
private final double distance;
// Minimum number of neighbors a point must have
private final int minPoints;
private final Map<CustomerPoint, PointStatus> visited = new HashMap<CustomerPoint, PointStatus>();
// Point markers. POINT: a point inside a cluster; NOISE: a noise point
private enum PointStatus {
NOISE,POINT
}
public DBScanTest3(final double distance, final int minPoints)
throws Exception {
if (distance < 0.0d) {
throw new Exception("distance must not be less than 0");
}
if (minPoints < 0) {
throw new Exception("minPoints must not be less than 0");
}
this.distance = distance;
this.minPoints = minPoints;
}
public double getDistance() {
return distance;
}
public int getMinPoints() {
return minPoints;
}
public Map<CustomerPoint, PointStatus> getVisited() {
return visited;
}
/**
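* Groups the given CustomerPoint objects into clusters.
* @param points the points to cluster
* @return the list of clusters; noise points are not included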
*/
public List<List<CustomerPoint>> cluster(List<CustomerPoint> points){
final List<List<CustomerPoint>> clusters = new ArrayList<List<CustomerPoint>>();
for (CustomerPoint point : points) {
// Skip points that have already been marked
if (visited.get(point) != null) {
continue;
}
List<CustomerPoint> neighbors = getNeighbors(point, points);
if (neighbors.size() >= minPoints) {
visited.put(point, PointStatus.POINT);
List<CustomerPoint> cluster = new ArrayList<CustomerPoint>();
// Walk through all neighbors and keep expanding the cluster
clusters.add(expandCluster(cluster, point, neighbors, points, visited));
} else {
visited.put(point, PointStatus.NOISE);
}
}
return clusters;
}
private List<CustomerPoint> expandCluster( List<CustomerPoint> cluster,
CustomerPoint point,
List<CustomerPoint> neighbors,
List<CustomerPoint> points,
Map<CustomerPoint, PointStatus> visited) {
cluster.add(point);
visited.put(point, PointStatus.POINT);
int index = 0;
// Iterate over all neighbors; the list grows while we iterate
while (index < neighbors.size()) {
// Move on to the current neighbor
CustomerPoint current = neighbors.get(index);
PointStatus pStatus = visited.get(current);
if (pStatus == null) {
List<CustomerPoint> currentNeighbors = getNeighbors(current, points);
neighbors.addAll(currentNeighbors);
}
// If the point is not yet marked as a cluster member, mark it and add it to the cluster
if (pStatus != PointStatus.POINT) {
visited.put(current, PointStatus.POINT);
cluster.add(current);
}
index++;
}
return cluster;
}
// Find all unmarked points within the distance threshold of the given point
private List<CustomerPoint> getNeighbors(CustomerPoint point,List<CustomerPoint> points) {
List<CustomerPoint> neighbors = new ArrayList<CustomerPoint>();
for (CustomerPoint neighbor : points) {
if (visited.get(neighbor) != null) {
continue;
}
if (point != neighbor && neighbor.distanceFrom(point) <= distance) {
neighbors.add(neighbor);
}
}
return neighbors;
}
// Build some sample data and run a test
public static void main(String[] args) throws Exception {
CustomerPoint customerPoint = new CustomerPoint(new double[] {3,8});
CustomerPoint customerPoint1 = new CustomerPoint(new double[] {4,7});
CustomerPoint customerPoint2 = new CustomerPoint(new double[] {4,8});
CustomerPoint customerPoint3 = new CustomerPoint(new double[] {5,6});
CustomerPoint customerPoint4 = new CustomerPoint(new double[] {3,9});
CustomerPoint customerPoint5 = new CustomerPoint(new double[] {5,1});
CustomerPoint customerPoint6 = new CustomerPoint(new double[] {5,2});
CustomerPoint customerPoint7 = new CustomerPoint(new double[] {6,3});
CustomerPoint customerPoint8 = new CustomerPoint(new double[] {7,3});
CustomerPoint customerPoint9 = new CustomerPoint(new double[] {7,4});
CustomerPoint customerPoint10 = new CustomerPoint(new double[] {0,2});
CustomerPoint customerPoint11 = new CustomerPoint(new double[] {8,16});
CustomerPoint customerPoint12 = new CustomerPoint(new double[] {1,1});
CustomerPoint customerPoint13 = new CustomerPoint(new double[] {1,3});
List<CustomerPoint> cs = new ArrayList<>();
cs.add(customerPoint13);
cs.add(customerPoint12);
cs.add(customerPoint11);
cs.add(customerPoint10);
cs.add(customerPoint9);
cs.add(customerPoint8);
cs.add(customerPoint7);
cs.add(customerPoint6);
cs.add(customerPoint5);
cs.add(customerPoint4);
cs.add(customerPoint3);
cs.add(customerPoint2);
cs.add(customerPoint1);
cs.add(customerPoint);
// First argument: distance threshold; second argument: minimum number of neighbors
DBScanTest3 db = new DBScanTest3(1.5, 1);
// Run the clustering and print the result
List<List<CustomerPoint>> aa =db.cluster(cs);
for(int i =0;i<aa.size();i++) {
for(int j=0;j<aa.get(i).size();j++) {
System.out.print(aa.get(i).get(j).toString());
}
System.out.println();
}
}
}
Printed result:
{lat:1.0,lng:3.0,value:0}{lat:0.0,lng:2.0,value:0}{lat:1.0,lng:1.0,value:0}
{lat:7.0,lng:4.0,value:0}{lat:7.0,lng:3.0,value:0}{lat:6.0,lng:3.0,value:0}{lat:5.0,lng:2.0,value:0}{lat:5.0,lng:1.0,value:0}
{lat:3.0,lng:9.0,value:0}{lat:4.0,lng:8.0,value:0}{lat:3.0,lng:8.0,value:0}{lat:4.0,lng:7.0,value:0}{lat:5.0,lng:6.0,value:0}
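Three groups are returned, and the one remaining point, (8, 16), is noise. If you plot the coordinates on a grid, you can see that they fall into three parts. Within each part, every point lies within 1.5 of at least one other point in the same part, although the two ends of a part can be farther apart than that.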
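The grouping can be cross-checked without the clustering class. The sketch below is not DBSCAN: it links every pair of points that are at most 1.5 apart with a small union-find and counts the resulting chains. It rests on the assumption that, with a minimum neighbor count of 1, the groups printed above coincide with these chains. The class ChainCheckSketch and its methods are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical cross-check for illustration only; it is not the DBSCAN code.
final class ChainCheckSketch {

    // The 14 test coordinates used in main above.
    static final double[][] SAMPLE = {
        {3, 8}, {4, 7}, {4, 8}, {5, 6}, {3, 9}, {5, 1}, {5, 2},
        {6, 3}, {7, 3}, {7, 4}, {0, 2}, {8, 16}, {1, 1}, {1, 3}
    };

    // Union-find root lookup with path halving.
    static int find(int[] parent, int i) {
        while (parent[i] != i) {
            parent[i] = parent[parent[i]];
            i = parent[i];
        }
        return i;
    }

    // Links every pair of points at most eps apart and returns the size of each group.
    static Map<Integer, Integer> groupSizes(double[][] pts, double eps) {
        int[] parent = new int[pts.length];
        for (int i = 0; i < parent.length; i++) {
            parent[i] = i;
        }
        for (int i = 0; i < pts.length; i++) {
            for (int j = i + 1; j < pts.length; j++) {
                double d = Math.hypot(pts[i][0] - pts[j][0], pts[i][1] - pts[j][1]);
                if (d <= eps) {
                    parent[find(parent, i)] = find(parent, j);
                }
            }
        }
        Map<Integer, Integer> sizes = new HashMap<>();
        for (int i = 0; i < pts.length; i++) {
            sizes.merge(find(parent, i), 1, Integer::sum);
        }
        return sizes;
    }

    // Counts the groups that contain at least minSize points.
    static int countGroups(double[][] pts, double eps, int minSize) {
        int count = 0;
        for (int size : groupSizes(pts, eps).values()) {
            if (size >= minSize) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int all = countGroups(SAMPLE, 1.5, 1);
        int multi = countGroups(SAMPLE, 1.5, 2);
        System.out.println("groups with 2+ points: " + multi);   // 3
        System.out.println("isolated points: " + (all - multi)); // 1
    }
}
```

On this data the check finds three chains of sizes 3, 5 and 5 plus one isolated point, which agrees with the printed result.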