java k均值_java实现kmeans算法

该博客详细介绍了KMeans聚类算法的工作原理,并提供了一个Java实现的代码示例。KMeans通过迭代优化,将数据集分成指定数量的类别,直到中心点的变化达到阈值。文章还展示了如何生成模拟数据并运行KMeans算法进行分类。
摘要由CSDN通过智能技术生成

kmeans算法是一种经典的聚类算法,其核心思想是:根据给定的聚类个数k,随机选择k个点作为初始的中心节点,然后按照样本中其他节点与这k个节点的距离进行分类。每分类一次就重新计算一次k个中心节点,直到所有样本中的节点所属的分类不再变化为止。

代码:

public class KmeansAlgorithm {

private static final int T = 10; // 最大迭代次数

private static final double THRESHOLD = 0.1; // 中心节点位置变化大小的阈值

public ArrayList> getClusters(ArrayList> dataSet, int k) {

int dataDimension = 0;

if(null != dataSet && dataSet.size() < k) {

System.out.println("data size is smaller than the number to be clustered");

} else {

dataDimension = dataSet.get(0).size();

}

// 为每条数据赋初始类别0

for(int i = 0; i < dataSet.size(); i++) {

dataSet.get(i).add(0d);

}

// 随机从数据集中选注k个点作为初始的k个中心节点

ArrayList> centerData = new ArrayList>();

for(int i = 0; i < k; i++) {

centerData.add(dataSet.get(i));

}

for(int i = 0; i < T; i++) {

for(int j = 0; j < dataSet.size(); j++) {

double classify = 0; // classify取值为0到k-1代表k个类别

double minDistance = computeDistance(dataSet.get(j), centerData.get(0));

for(int l = 1; l < centerData.size(); l++) {

if(computeDistance(dataSet.get(j), centerData.get(l)) < minDistance) {

minDistance = computeDistance(dataSet.get(j), centerData.get(l));

classify = l;

}

}

dataSet.get(j).set(dataDimension, classify);

}

// 每次分类后计算中心节点的位置变化情况

double variance = computeChange(dataSet, centerData, k, dataDimension);

if(variance < THRESHOLD) {

break;

}

// 每次分类后重新计算中心节点

centerData = computeCenterData(dataSet, k, dataDimension);

}

return dataSet;

}

/**

*

* @Title: computeDistance

* @Description: 计算任意两个节点间的距离

* @return double

* @throws

*/

public double computeDistance(ArrayList d1, ArrayList d2) {

double squareSum = 0;

for(int i = 0; i < d1.size() - 1; i++) {

squareSum += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));

}

return Math.sqrt(squareSum);

}

/**

*

* @Title: computeCenterData

* @Description: 计算中心节点

* @return ArrayList

* @throws

*/

public ArrayList> computeCenterData(ArrayList> dataSet, int k, int dataDimension) {

ArrayList> res = new ArrayList>();

for(int i = 0; i < k; i++) {

int ClassNum = 0;

ArrayList tmp = new ArrayList();

for(int l = 0; l < dataDimension; l++) {

tmp.add(0d);

}

for(int j = 0; j < dataSet.size(); j++) {

if(dataSet.get(j).get(dataDimension) == i) {

ClassNum++;

for(int m = 0; m < dataDimension; m++) {

tmp.set(m, tmp.get(m) + dataSet.get(j).get(m));

}

}

}

for(int l = 0; l < dataDimension; l++) {

tmp.set(l, tmp.get(l) / (double)ClassNum);

}

res.add(tmp);

}

return res;

}

/**

*

* @Title: computeChange

* @Description: 计算两轮迭代中心节点位置的变化量

* @return double

* @throws

*/

public double computeChange(ArrayList> dataSet, ArrayList> centerData, int k, int dataDimension) {

double variance = 0;

ArrayList> originalCenterData = computeCenterData(dataSet, k, dataDimension);

for(int i = 0; i < centerData.size(); i++) {

variance += computeDistance(originalCenterData.get(i), centerData.get(i));

}

return variance;

}

public static void main(String[] args) {

final int CLUSTER1_NUM = 4;

final int CLUSTER2_NUM = 4;

final int CLUSTER3_NUM = 4;

ArrayList> dataSet = new ArrayList>();

// 产生簇1

for(int i = 0; i < CLUSTER1_NUM; i++) {

ArrayList cluster1 = new ArrayList();

cluster1.add(1 + Math.random() * 2);

cluster1.add(1 + Math.random() * 2);

dataSet.add(cluster1);

}

// 产生簇2

for(int i = 0; i < CLUSTER2_NUM; i++) {

ArrayList cluster2 = new ArrayList();

cluster2.add(Math.random());

cluster2.add(Math.random());

dataSet.add(cluster2);

}

// 产生簇3

for(int i = 0; i < CLUSTER3_NUM; i++) {

ArrayList cluster3 = new ArrayList();

cluster3.add(3 + Math.random());

cluster3.add(3 + Math.random());

dataSet.add(cluster3);

}

KmeansAlgorithm d = new KmeansAlgorithm();

ArrayList> dd = d.getClusters(dataSet, 3);

System.out.println(dd);

}

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值