Accelerating k-means with a kd-tree (Java)

```java
package cc;

import java.util.ArrayList;
import java.util.HashMap;

public class MRKDTree {

    private Node mrkdtree;

    private class Node {
        // dimension the node splits on
        int partitionDimention;
        // split value on that dimension
        double partitionValue;
        // null for internal nodes; the data point for leaf nodes
        double[] value;
        // whether this node is a leaf
        boolean isLeaf = false;
        // left subtree
        Node left;
        // right subtree
        Node right;
        // per-dimension minimum of the points under this node
        double[] min;
        // per-dimension maximum of the points under this node
        double[] max;
        // per-dimension sum of the points under this node
        double[] sumOfPoints;
        // number of points under this node
        int n;
    }

    private static class UtilZ {
        /**
         * Variance of the data along a given dimension.
         * @param data the data
         * @param dimention the dimension
         * @return the variance
         */
        static double variance(ArrayList<double[]> data, int dimention) {
            double vsum = 0;
            double sum = 0;
            for (double[] d : data) {
                sum += d[dimention];
                vsum += d[dimention] * d[dimention];
            }
            int n = data.size();
            return vsum / n - Math.pow(sum / n, 2);
        }

        /**
         * Median (middle element after sorting) along a given dimension.
         * @param data the data
         * @param dimention the dimension
         * @return the median
         */
        static double median(ArrayList<double[]> data, int dimention) {
            double[] d = new double[data.size()];
            int i = 0;
            for (double[] k : data) {
                d[i++] = k[dimention];
            }
            return median(d);
        }

        private static double median(double[] a) {
            // Hoare-style selection of the element at position n/2
            int n = a.length;
            int L = 0;
            int R = n - 1;
            int k = n / 2;
            int i;
            int j;
            while (L < R) {
                double x = a[k];
                i = L;
                j = R;
                do {
                    while (a[i] < x) i++;
                    while (x < a[j]) j--;
                    if (i <= j) {
                        double t = a[i];
                        a[i] = a[j];
                        a[j] = t;
                        i++;
                        j--;
                    }
                } while (i <= j);
                if (j < k) L = i;
                if (k < i) R = j;
            }
            return a[k];
        }

        static double[][] maxmin(ArrayList<double[]> data, int dimentions) {
            double[][] mm = new double[2][dimentions];
            // row 0 holds the per-dimension minimum, row 1 the maximum
            for (int i = 0; i < dimentions; i++) {
                mm[0][i] = mm[1][i] = data.get(0)[i];
                for (int j = 1; j < data.size(); j++) {
                    double[] d = data.get(j);
                    if (d[i] < mm[0][i]) {
                        mm[0][i] = d[i];
                    } else if (d[i] > mm[1][i]) {
                        mm[1][i] = d[i];
                    }
                }
            }
            return mm;
        }

        static double distance(double[] a, double[] b) {
            // squared Euclidean distance
            double sum = 0;
            for (int i = 0; i < a.length; i++) {
                sum += Math.pow(a[i] - b[i], 2);
            }
            return sum;
        }

        /**
         * Minimum distance between point a and any point inside the
         * hyperrectangle described by max and min.
         * @param a the point
         * @param max per-dimension maximum of the hyperrectangle
         * @param min per-dimension minimum of the hyperrectangle
         * @return the minimum squared distance
         */
        static double mindistance(double[] a, double[] max, double[] min) {
            double sum = 0;
            for (int i = 0; i < a.length; i++) {
                if (a[i] > max[i]) {
                    sum += Math.pow(a[i] - max[i], 2);
                } else if (a[i] < min[i]) {
                    sum += Math.pow(min[i] - a[i], 2);
                }
            }
            return sum;
        }

        public static double[] sumOfPoints(ArrayList<double[]> data, int dimentions) {
            double[] res = new double[dimentions];
            for (double[] d : data) {
                for (int i = 0; i < dimentions; i++) {
                    res[i] += d[i];
                }
            }
            return res;
        }

        /**
         * Checks whether center dominates c over the hyperrectangle, i.e. whether
         * every point of the hyperrectangle is closer to center than to c.
         * @param center the candidate owner
         * @param c the competing center
         * @param max per-dimension maximum of the hyperrectangle
         * @param min per-dimension minimum of the hyperrectangle
         * @return true if center dominates c
         */
        public static boolean isOver(double[] center, double[] c, double[] max, double[] min) {
            // compare the distances to the corner of the hyperrectangle
            // that lies farthest in the direction of c
            double discenter = 0;
            double disc = 0;
            for (int i = 0; i < center.length; i++) {
                if (c[i] - center[i] > 0) {
                    disc += Math.pow(max[i] - c[i], 2);
                    discenter += Math.pow(max[i] - center[i], 2);
                } else if (c[i] - center[i] < 0) {
                    disc += Math.pow(min[i] - c[i], 2);
                    discenter += Math.pow(min[i] - center[i], 2);
                }
            }
            return discenter < disc;
        }
    }

    private MRKDTree() {}

    /**
     * Builds the tree.
     * @param input the input data
     * @return the kd-tree
     */
    public static MRKDTree build(double[][] input) {
        int n = input.length;
        int m = input[0].length;
        ArrayList<double[]> data = new ArrayList<double[]>(n);
        for (int i = 0; i < n; i++) {
            double[] d = new double[m];
            for (int j = 0; j < m; j++) {
                d[j] = input[i][j];
            }
            data.add(d);
        }
        MRKDTree tree = new MRKDTree();
        tree.mrkdtree = tree.new Node();
        tree.buildDetail(tree.mrkdtree, data, m, 0);
        return tree;
    }

    /**
     * Recursively builds the tree.
     * @param node the current node
     * @param data the data
     * @param dimentions the dimensionality of the data
     */
    private void buildDetail(Node node, ArrayList<double[]> data, int dimentions, int lv) {
        if (data.size() == 1) {
            node.isLeaf = true;
            node.value = data.get(0);
            return;
        }
        // Choose the dimension with the largest variance.
        /*
        node.partitionDimention = -1;
        double var = -1;
        double tmpvar;
        for (int i = 0; i < dimentions; i++) {
            tmpvar = UtilZ.variance(data, i);
            if (tmpvar > var) {
                var = tmpvar;
                node.partitionDimention = i;
            }
        }
        // variance == 0 means all points are identical, so treat this as a leaf
        if (var < 1e-10) {
            node.isLeaf = true;
            node.value = data.get(0);
            return;
        }
        */
        double[][] maxmin = UtilZ.maxmin(data, dimentions);
        node.min = maxmin[0];
        node.max = maxmin[1];
        // Picking the dimension with the largest variance (above) takes too long,
        // so pick the dimension with the largest range instead. Building the
        // kd-tree becomes faster, but the k-means centroid update gets slower.
        boolean isleaf = true;
        for (int i = 0; i < dimentions; i++) {
            if (node.max[i] > node.min[i]) {
                isleaf = false;
                break;
            }
        }
        if (isleaf) {
            node.isLeaf = true;
            node.value = data.get(0);
            return;
        }
        node.partitionDimention = -1;
        double diff = -1;
        double tmpdiff;
        for (int i = 0; i < dimentions; i++) {
            tmpdiff = node.max[i] - node.min[i];
            if (tmpdiff > diff) {
                diff = tmpdiff;
                node.partitionDimention = i;
            }
        }
        node.sumOfPoints = UtilZ.sumOfPoints(data, dimentions);
        node.n = data.size();
        // choose the split value
        node.partitionValue = UtilZ.median(data, node.partitionDimention);
        if (node.partitionValue == node.min[node.partitionDimention]) {
            node.partitionValue += 1e-5;
        }
        int size = (int) (data.size() * 0.55);
        ArrayList<double[]> left = new ArrayList<double[]>(size);
        ArrayList<double[]> right = new ArrayList<double[]>(size);
        for (double[] d : data) {
            if (d[node.partitionDimention] < node.partitionValue) {
                left.add(d);
            } else {
                right.add(d);
            }
        }
        Node leftnode = new Node();
        Node rightnode = new Node();
        node.left = leftnode;
        node.right = rightnode;
        buildDetail(leftnode, left, dimentions, lv + 1);
        buildDetail(rightnode, right, dimentions, lv + 1);
    }

    public double[][] updateCentroids(double[][] cs) {
        int k = cs.length;
        int m = cs[0].length;
        double[][] entroids = new double[k][m];
        int[] datacount = new int[k];
        HashMap<Integer, double[]> cscopy = new HashMap<Integer, double[]>();
        for (int i = 0; i < k; i++) {
            cscopy.put(i, cs[i]);
        }
        updateCentroidsDetail(mrkdtree, cscopy, entroids, datacount, k, m);
        double[][] csnew = new double[k][m];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < m; j++) {
                // keep the old center if no point was assigned to it (avoids 0/0)
                csnew[i][j] = datacount[i] > 0 ? entroids[i][j] / datacount[i] : cs[i][j];
            }
        }
        return csnew;
    }

    private void updateCentroidsDetail(Node node, HashMap<Integer, double[]> cs,
            double[][] entroids, int[] datacount, int k, int m) {
        // leaf node: assign the single point to its nearest center
        if (node.isLeaf) {
            double[] v = node.value;
            double dis = Double.MAX_VALUE;
            double tdis;
            int index = -1;
            // find the closest center
            for (Integer i : cs.keySet()) {
                double[] c = cs.get(i);
                tdis = UtilZ.distance(c, v);
                if (tdis < dis) {
                    dis = tdis;
                    index = i;
                }
            }
            // update the statistics
            datacount[index]++;
            for (int i = 0; i < m; i++) {
                entroids[index][i] += v[i];
            }
            return;
        }
        // find the center(s) with the smallest minimum distance to this node's hyperrectangle
        double[] stack = new double[k];
        int stackpoint = 0;
        int center = 0;
        double tdis;
        for (Integer i : cs.keySet()) {
            double[] c = cs.get(i);
            tdis = UtilZ.mindistance(c, node.max, node.min);
            if (stackpoint == 0) {
                stack[stackpoint++] = tdis;
                center = i;
            } else if (tdis < stack[stackpoint - 1]) {
                stackpoint = 1;
                stack[0] = tdis;
                center = i;
            } else if (tdis == stack[stackpoint - 1]) {
                stack[stackpoint++] = tdis;
            }
        }
        // stackpoint > 1 means several centers tie for the minimum, so there is
        // no single candidate owner: recurse into both children
        if (stackpoint != 1) {
            updateCentroidsDetail(node.left, cs, entroids, datacount, k, m);
            updateCentroidsDetail(node.right, cs, entroids, datacount, k, m);
            return;
        }
        // mark every center that is dominated by `center` over this hyperrectangle
        HashMap<Integer, Boolean> ctover = new HashMap<Integer, Boolean>();
        double[] centerd = cs.get(center);
        for (Integer i : cs.keySet()) {
            if (i == center) continue;
            double[] c = cs.get(i);
            if (UtilZ.isOver(centerd, c, node.max, node.min)) {
                ctover.put(i, true);
            }
        }
        if (ctover.size() == cs.size() - 1) {
            // all other centers are dominated: every point in this node belongs
            // to `center`, so update its statistics in one step
            datacount[center] += node.n;
            for (int i = 0; i < m; i++) {
                entroids[center][i] += node.sumOfPoints[i];
            }
            return;
        }
        // drop the centers that are dominated by `center` and recurse
        HashMap<Integer, double[]> csnew = new HashMap<Integer, double[]>();
        for (Integer i : cs.keySet()) {
            if (!ctover.containsKey(i)) {
                csnew.put(i, cs.get(i));
            }
        }
        updateCentroidsDetail(node.left, csnew, entroids, datacount, k, m);
        updateCentroidsDetail(node.right, csnew, entroids, datacount, k, m);
    }
}
```
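The class only exposes `build` and `updateCentroids`, so the k-means loop itself has to be written by the caller. Below is a minimal sketch of such a driver; the class name `KMeansDriver`, the random initialization, the iteration cap, and the convergence threshold are my own choices for illustration, not part of the original code.

```java
package cc;

import java.util.Random;

public class KMeansDriver {
    public static double[][] run(double[][] data, int k, int maxIter) {
        // build the kd-tree once; it is reused by every centroid update
        MRKDTree tree = MRKDTree.build(data);

        // hypothetical initialization: pick k random points as the initial centers
        Random rnd = new Random(42);
        double[][] centers = new double[k][];
        for (int i = 0; i < k; i++) {
            centers[i] = data[rnd.nextInt(data.length)].clone();
        }

        for (int iter = 0; iter < maxIter; iter++) {
            // one filtering pass over the tree replaces the usual
            // point-by-point assignment step
            double[][] next = tree.updateCentroids(centers);

            // stop when the centers barely move (threshold chosen arbitrarily)
            double shift = 0;
            for (int i = 0; i < k; i++) {
                for (int j = 0; j < centers[i].length; j++) {
                    shift += Math.pow(next[i][j] - centers[i][j], 2);
                }
            }
            centers = next;
            if (shift < 1e-8) {
                break;
            }
        }
        return centers;
    }
}
```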

Below are three different implementations of the k-means algorithm for comparison:

1. Basic k-means

```python
import numpy as np

# Input: data set X (n x d) and number of clusters k
# Output: cluster assignments `clusters` and cluster centers `centroids`
def k_means(X, k):
    X = np.asarray(X, dtype=float)
    # initialize the centers with the first k points
    centroids = X[:k].copy()
    while True:
        # assign every point to its nearest center
        clusters = [[] for _ in range(k)]
        for x in X:
            distances = [np.sum((x - c) ** 2) for c in centroids]
            clusters[int(np.argmin(distances))].append(x)
        # recompute the centers; keep the old center for an empty cluster
        new_centroids = np.array([
            np.mean(cluster, axis=0) if cluster else centroids[i]
            for i, cluster in enumerate(clusters)
        ])
        # stop when the centers no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return clusters, centroids
```

2. Accelerated k-means (using a k-d tree)

```python
import numpy as np
from sklearn.neighbors import KDTree

# Input: data set X (n x d) and number of clusters k
# Output: cluster assignments `clusters` and cluster centers `centroids`
def k_means_kd(X, k):
    X = np.asarray(X, dtype=float)
    centroids = X[:k].copy()
    while True:
        # build a k-d tree over the current centers and find the nearest
        # center of every point with a single batched query
        tree = KDTree(centroids)
        _, indices = tree.query(X, k=1)
        clusters = [[] for _ in range(k)]
        for x, idx in zip(X, indices[:, 0]):
            clusters[idx].append(x)
        # recompute the centers; keep the old center for an empty cluster
        new_centroids = np.array([
            np.mean(cluster, axis=0) if cluster else centroids[i]
            for i, cluster in enumerate(clusters)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return clusters, centroids
```

3. Mini-batch k-means

```python
import numpy as np

# Input: data set X (n x d), number of clusters k and batch size batch_size
# Output: cluster assignments `clusters` and cluster centers `centroids`
def mini_batch_k_means(X, k, batch_size, max_iter=100):
    X = np.asarray(X, dtype=float)
    centroids = X[:k].copy()
    clusters = [[] for _ in range(k)]
    for _ in range(max_iter):
        # sample a random batch of rows (sample indices, because
        # np.random.choice only accepts 1-D arrays)
        batch = X[np.random.choice(len(X), batch_size, replace=False)]
        clusters = [[] for _ in range(k)]
        for x in batch:
            distances = [np.sum((x - c) ** 2) for c in centroids]
            clusters[int(np.argmin(distances))].append(x)
        # recompute the centers from the batch; keep the old center
        # for an empty cluster
        new_centroids = np.array([
            np.mean(cluster, axis=0) if cluster else centroids[i]
            for i, cluster in enumerate(clusters)
        ])
        # stop when the centers no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return clusters, centroids
```
