packagecc;importjava.util.ArrayList;importjava.util.HashMap;public classMRKDTree {privateNode mrkdtree;private classNode{//分割的维度
intpartitionDimention;//分割的值
doublepartitionValue;//如果为非叶子节点,该属性为空//否则为数据
double[] value;//是否为叶子
boolean isLeaf=false;//左树
Node left;//右树
Node right;//每个维度的最小值
double[] min;//每个维度的最大值
double[] max;double[] sumOfPoints;intn;
}private static classUtilZ{/*** 计算给定维度的方差
*@paramdata 数据
*@paramdimention 维度
*@return方差*/
static double variance(ArrayList data,intdimention){double vsum = 0;double sum = 0;for(double[] d:data){
sum+=d[dimention];
vsum+=d[dimention]*d[dimention];
}int n =data.size();return vsum/n-Math.pow(sum/n, 2);
}/*** 取排序后的中间位置数值
*@paramdata 数据
*@paramdimention 维度
*@return
*/
static double median(ArrayList data,intdimention){double[] d =new double[data.size()];int i=0;for(double[] k:data){
d[i++]=k[dimention];
}returnmedian(d);
}private static double median(double[] a){int n=a.length;int L = 0;int R = n - 1;int k = n / 2;inti;intj;while (L
i=L;
j=R;do{while (a[i]
i++;while (x
j--;if (i <=j) {double t =a[i];
a[i]=a[j];
a[j]=t;
i++;
j--;
}
}while (i <=j);if (j
L=i;if (k
R=j;
}returna[k];
}static double[][] maxmin(ArrayList data,intdimentions){double[][] mm = new double[2][dimentions];//初始化 第一行为min,第二行为max
for(int i=0;i
mm[0][i]=mm[1][i]=data.get(0)[i];for(int j=1;j
mm[0][i]=d[i];
}else if(d[i]>mm[1][i]){
mm[1][i]=d[i];
}
}
}returnmm;
}static double distance(double[] a,double[] b){double sum = 0;for(int i=0;i
sum+=Math.pow(a[i]-b[i], 2);
}returnsum;
}/*** 在max和min表示的超矩形中的点和点a的最小距离
*@parama 点a
*@parammax 超矩形各个维度的最大值
*@parammin 超矩形各个维度的最小值
*@return超矩形中的点和点a的最小距离*/
static double mindistance(double[] a,double[] max,double[] min){double sum = 0;for(int i=0;imax[i])
sum+= Math.pow(a[i]-max[i], 2);else if (a[i]
sum+= Math.pow(min[i]-a[i], 2);
}
}returnsum;
}public static double[] sumOfPoints(ArrayListdata,intdimentions) {double[] res = new double[dimentions];for(double[] d:data){for(int i=0;i
res[i]+=d[i];
}
}returnres;
}/*** 判断centerd是否在h上优于c
*@paramcenterd
*@paramc
*@parammax
*@parammin
*@return
*/
public static boolean isOver(double[] center, double[] c,double[] max, double[] min) {double discenter = 0;double disc = 0;for(int i=0;i0){
disc+=Math.pow(max[i]-c[i],2);
discenter+=Math.pow(max[i]-center[i],2);
}else if(c[i]-center[i]<0) {
disc+=Math.pow(min[i]-c[i],2);
discenter+=Math.pow(min[i]-center[i],2);
}
}return discenter
}
}privateMRKDTree() {}/*** 构建树
*@paraminput 输入
*@returnKDTree树*/
public static MRKDTree build(double[][] input){int n =input.length;int m = input[0].length;
ArrayList data =new ArrayList(n);for(int i=0;i
d[j]=input[i][j];
data.add(d);
}
MRKDTree tree= newMRKDTree();
tree.mrkdtree= tree.newNode();
tree.buildDetail(tree.mrkdtree, data, m,0);returntree;
}/*** 循环构建树
*@paramnode 节点
*@paramdata 数据
*@paramdimentions 数据的维度*/
private void buildDetail(Node node,ArrayList data,int dimentions,intlv){if(data.size()==1){
node.isLeaf=true;
node.value=data.get(0);return;
}//选择方差最大的维度
/*node.partitionDimention=-1;
double var = -1;
double tmpvar;
for(int i=0;i
tmpvar=UtilZ.variance(data, i);
if (tmpvar>var){
var = tmpvar;
node.partitionDimention = i;
}
}
//如果方差=0,表示所有数据都相同,判定为叶子节点
if(var<1e-10){
node.isLeaf=true;
node.value=data.get(0);
return;
}*/
double[][] maxmin=UtilZ.maxmin(data, dimentions);
node.min= maxmin[0];
node.max= maxmin[1];//选取方差大的维度,会需要很长时间//改成使用选取数据范围最大的维度//这样构建kdtree的速度会变快,但是在kmean更新中心点会变慢
boolean isleaf = true;for(int i=0;i
isleaf=false;break;
}if(isleaf){
node.isLeaf=true;
node.value=data.get(0);return;
}
node.partitionDimention=-1;double diff = -1;doubletmpdiff;for(int i=0;i
tmpdiff=node.max[i]-node.min[i];if (tmpdiff>diff){
diff=tmpdiff;
node.partitionDimention=i;
}
}
node.sumOfPoints=UtilZ.sumOfPoints(data,dimentions);
node.n=data.size();//选择分割的值
node.partitionValue=UtilZ.median(data, node.partitionDimention);if(node.partitionValue==node.min[node.partitionDimention]){
node.partitionValue+=1e-5;
}int size = (int)(data.size()*0.55);
ArrayList left = new ArrayList(size);
ArrayList right = new ArrayList(size);for(double[] d:data){if (d[node.partitionDimention]
left.add(d);
}else{
right.add(d);
}
}
Node leftnode= newNode();
Node rightnode= newNode();
node.left=leftnode;
node.right=rightnode;
buildDetail(leftnode, left, dimentions,lv+1);
buildDetail(rightnode, right, dimentions,lv+1);
}public double[][] updateCentroids(double[][] cs){int k =cs.length;int m = cs[0].length;double[][] entroids = new double[k][m];int[] datacount = new int[k];
HashMap cscopy = new HashMap();for(int i=0;i
cscopy.put(i, cs[i]);
updateCentroidsDetail(mrkdtree,cscopy,entroids,datacount,k,m);double[][] csnew = new double[k][m];for(int i=0;i
csnew[i][j]=entroids[i][j]/datacount[i];
}
}returncsnew;
}private voidupdateCentroidsDetail(Node node,
HashMap cs, double[][] entroids,int[] datacount,int k,intm) {//如果是叶子节点
if(node.isLeaf){double[] v=node.value;double dis=Double.MAX_VALUE;doubletdis;int index = -1;//找到所属的中心点
for(Integer i: cs.keySet()){double[] c =cs.get(i);
tdis=UtilZ.distance(c, v);if(tdis
dis=tdis;
index=i;
}
}//更新统计信息
datacount[index]++;for(int i=0;i
entroids[index][i]+=v[i];
}return;
}double[] stack = new double[k];int stackpoint = 0;int center=0;doubletdis;for(Integer i: cs.keySet()){double[] c =cs.get(i);
tdis=UtilZ.mindistance(c, node.max, node.min);if(stackpoint==0){
stack[stackpoint++]=tdis;
center=i;
}else if (tdis
stackpoint=1;
stack[0]=tdis;
center=i;
}else if (tdis==stack[stackpoint-1]) {
stack[stackpoint++]=tdis;
}
}//stackpoint>1,说明有多个最小值,不存在中心点
if(stackpoint!=1){
updateCentroidsDetail(node.left, cs, entroids, datacount, k, m);
updateCentroidsDetail(node.right, cs, entroids, datacount, k, m);return;
}
HashMap ctover = new HashMap();double[] centerd =cs.get(center);for(Integer i: cs.keySet()){if(i==center) continue;double[] c =cs.get(i);if(UtilZ.isOver(centerd,c,node.max,node.min)){
ctover.put(i,true);
}
}if(ctover.size()==cs.size()-1){//此时中心点即为center,更新信息
datacount[center]+=node.n;for(int i=0;i
entroids[center][i]+=node.sumOfPoints[i];
}return;
}//将其比center差的中心点排除
HashMap csnew = new HashMap();for(Integer i:cs.keySet()){if(!ctover.containsKey(i))
csnew.put(i, cs.get(i));
}
updateCentroidsDetail(node.left, csnew, entroids, datacount, k, m);
updateCentroidsDetail(node.right, csnew, entroids, datacount, k, m);
}
}