K-Means 算法(Java 实现)

K-Means 算法的原理见我的上一篇文章。本文介绍 K-Means 的 Java 实现,实现思路参考了 Python 版本。

一、数据点的实现

package com.meachine.learning.kmeans;

import java.util.ArrayList;

/**

* 数据点,有n维数据

*

*/

/**
 * A data point with an arbitrary number of numeric coordinates, plus the
 * bookkeeping K-Means needs: the id of the cluster the point is currently
 * assigned to and its distance to that cluster's center.
 */
public class Point {

    // Counter used to hand out sequential ids; not thread-safe.
    private static int num;

    private int id;

    private int dimensioNum; // number of coordinates

    private ArrayList<Double> values;

    // -1 = not yet assigned to any cluster.
    private int clusterId = -1;

    // Starts at the largest double so that any real distance compares smaller.
    private double minDist = Double.MAX_VALUE;

    public Point() {
        id = ++num;
        values = new ArrayList<>();
    }

    /**
     * Appends one coordinate and increases the dimension count.
     *
     * @param e the coordinate value to append
     */
    public void add(double e) {
        values.add(e);
        dimensioNum++;
    }

    /** @return the sequential id assigned at construction */
    public int getId() {
        return id;
    }

    /** @return the number of coordinates of this point */
    public int getDimensioNum() {
        return dimensioNum;
    }

    /** @return the live (mutable) list of coordinates */
    public ArrayList<Double> getValues() {
        return values;
    }

    /**
     * Replaces the coordinates and keeps the dimension count in sync with them.
     *
     * @param values the new coordinates; must not be null (dereferenced via size())
     */
    public void setValues(ArrayList<Double> values) {
        this.dimensioNum = values.size();
        this.values = values;
    }

    /** @return the id of the assigned cluster, or -1 if unassigned */
    public int getClusterId() {
        return clusterId;
    }

    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /** @return the distance to the assigned cluster's center */
    public double getMinDist() {
        return minDist;
    }

    public void setMinDist(double minDist) {
        this.minDist = minDist;
    }
}

二、数据簇的实现

package com.meachine.learning.kmeans;

import lombok.EqualsAndHashCode;

import lombok.Getter;

import lombok.Setter;

import lombok.ToString;

/**

* 簇

* 数据集合的基本信息

*

*/

/**
 * A cluster: its id, the number of points belonging to it, and its current center.
 */
public class Cluster {

    // Cluster id
    private int clusterId;

    // Number of points belonging to this cluster
    private int numOfPoints;

    // Center point of the cluster (null until set)
    private Point center;

    /**
     * Creates a cluster with no center yet.
     *
     * @param id the cluster id
     */
    public Cluster(int id) {
        this(id, null);
    }

    /**
     * Creates a cluster with the given initial center.
     *
     * @param id     the cluster id
     * @param center the initial center point
     */
    public Cluster(int id, Point center) {
        this.clusterId = id;
        this.center = center;
    }

    public int getClusterId() {
        return clusterId;
    }

    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    public int getNumOfPoints() {
        return numOfPoints;
    }

    public void setNumOfPoints(int numOfPoints) {
        this.numOfPoints = numOfPoints;
    }

    public Point getCenter() {
        return center;
    }

    public void setCenter(Point center) {
        this.center = center;
    }
}

三、计算数据点距离

package com.meachine.learning.kmeans;

import java.util.List;

/**

* 计算距离接口

*

*/

public interface IDistance {

public double getDis(List p1, List p2);

}

package com.meachine.learning.kmeans;

import java.util.List;

/**

* 欧式距离

*

*/

public class OujilidDistance implements IDistance {

public double getDis(List a, List b) {

if (a.size() != b.size()) {

throw new IllegalArgumentException("Size not compatible!");

}

double result = 0;

for (int i = 0; i < a.size(); i++) {

result += Math.pow((a.get(i).doubleValue() - b.get(i).doubleValue()), 2);

}

return Math.sqrt(result);

}

}

四、K-Means算法

package com.meachine.learning.kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;

/**

* K-Means算法

*

* @author Cang

*

*/

/**
 * K-Means clustering of points read from a whitespace-separated text file.
 *
 * <p>Usage: construct with the number of clusters, call {@link #init()} to load the
 * data and pick initial centers, then call {@link #clustering()} to iterate until no
 * point changes cluster or the iteration limit is reached.
 *
 * @author Cang
 */
public class KMeans {

    /** Data file used when no file name is given to the constructor. */
    private static final String DEFAULT_DATA_FILE = "D:/testSet.txt";

    // Number of clusters
    private int k;

    // Number of dimensions (variables) per point; taken from the data file
    private int dimensioNum;

    // Maximum number of iterations
    private int maxItrNum = 100;

    private IDistance distance;

    private List<Point> points;

    private List<Cluster> clusters = new ArrayList<>();

    private String dataFileName;

    /**
     * Creates a K-Means run over the default data file.
     *
     * @param k number of clusters; must be positive
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public KMeans(int k) {
        this(k, DEFAULT_DATA_FILE);
    }

    /**
     * Creates a K-Means run over the given data file.
     *
     * @param k            number of clusters; must be positive
     * @param dataFileName path of a file with one point per line, coordinates
     *                     separated by whitespace (e.g. tabs)
     * @throws IllegalArgumentException if {@code k <= 0}
     * @throws NullPointerException     if {@code dataFileName} is null
     */
    public KMeans(int k, String dataFileName) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        this.k = k;
        this.dataFileName = Objects.requireNonNull(dataFileName, "dataFileName");
    }

    /**
     * Loads the data set, selects the Euclidean distance and picks the initial centers.
     *
     * @throws UncheckedIOException  if the data file cannot be read
     * @throws IllegalStateException if the file holds no points or fewer than k points
     */
    public void init() {
        points = loadDataSet(dataFileName);
        distance = new OujilidDistance();
        initCluster();
    }

    /**
     * Reads one point per non-blank line and records the dimension count.
     *
     * @param fileName path of the data file
     * @return the parsed points, in file order
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalArgumentException if rows have differing numbers of coordinates
     * @throws NumberFormatException    if a coordinate is not a number
     */
    private List<Point> loadDataSet(String fileName) {
        List<Point> loaded = new ArrayList<>();
        int expectedDims = -1;
        // try-with-resources closes the reader even when parsing throws.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline at EOF
                }
                String[] fields = trimmed.split("\\s+");
                if (expectedDims < 0) {
                    expectedDims = fields.length;
                } else if (fields.length != expectedDims) {
                    throw new IllegalArgumentException("Line " + lineNo + " of " + fileName
                            + " has " + fields.length + " values, expected " + expectedDims);
                }
                Point point = new Point();
                for (String field : fields) {
                    point.add(Double.parseDouble(field));
                }
                loaded.add(point);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read data set " + fileName, e);
        }
        dimensioNum = Math.max(expectedDims, 0);
        return loaded;
    }

    /**
     * Creates k clusters (ids 1..k) whose initial centers are k distinct,
     * randomly chosen data points. Any previously created clusters are discarded.
     *
     * @throws IllegalStateException if there are no points or fewer than k points
     */
    private void initCluster() {
        if (points.isEmpty()) {
            throw new IllegalStateException("No data points loaded from " + dataFileName);
        }
        if (k > points.size()) {
            throw new IllegalStateException("k (" + k + ") exceeds the number of data points ("
                    + points.size() + ")");
        }
        clusters.clear();
        // Shuffle a copy so that the first k entries are k distinct points.
        List<Point> shuffled = new ArrayList<>(points);
        Collections.shuffle(shuffled, new Random());
        for (int id = 1; id <= k; id++) {
            clusters.add(new Cluster(id, shuffled.get(id - 1)));
        }
    }

    /**
     * Runs K-Means: repeatedly assigns every point to its nearest center and moves
     * each center to the mean of its points, printing the centers every iteration.
     * Stops when no point changes cluster or after {@code maxItrNum} iterations.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     */
    public void clustering() {
        if (points == null || clusters.isEmpty()) {
            throw new IllegalStateException("init() must be called before clustering()");
        }
        boolean finished = false;
        int count = 0;
        while (!finished) {
            boolean changed = assignPoints();

            System.out.println("Cluster center info:");
            for (Cluster cluster : clusters) {
                System.out.println(cluster.getCenter().getValues());
            }

            changeCentroids();

            finished = !changed || ++count >= maxItrNum;
        }
    }

    /**
     * Assigns each point to its nearest cluster center and records that distance.
     *
     * @return true if at least one point ended up in a different cluster than before
     */
    private boolean assignPoints() {
        boolean changed = false;
        for (Point point : points) {
            // Search from scratch every iteration: distances from the previous
            // iteration refer to old center positions and must not be reused.
            Cluster nearest = null;
            double nearestDist = Double.MAX_VALUE;
            for (Cluster cluster : clusters) {
                double d = distance.getDis(cluster.getCenter().getValues(), point.getValues());
                if (nearest == null || d < nearestDist) {
                    nearest = cluster;
                    nearestDist = d;
                }
            }
            // Compare only the final choice with the previous assignment.
            if (nearest.getClusterId() != point.getClusterId()) {
                changed = true;
                point.setClusterId(nearest.getClusterId());
            }
            point.setMinDist(nearestDist);
        }
        return changed;
    }

    /**
     * Moves every cluster center to the per-dimension mean of the points assigned
     * to it. A cluster with no points keeps its current center.
     */
    private void changeCentroids() {
        for (Cluster cluster : clusters) {
            double[] sums = new double[dimensioNum];
            int members = 0;
            for (Point point : points) {
                if (point.getClusterId() == cluster.getClusterId()) {
                    members++;
                    List<Double> values = point.getValues();
                    for (int i = 0; i < dimensioNum; i++) {
                        sums[i] += values.get(i);
                    }
                }
            }
            if (members == 0) {
                continue; // avoid 0/0 (NaN) centers for empty clusters
            }
            ArrayList<Double> newCenterValue = new ArrayList<>(dimensioNum);
            for (double sum : sums) {
                newCenterValue.add(sum / members);
            }
            Point newCenterPoint = new Point();
            newCenterPoint.setClusterId(cluster.getClusterId());
            newCenterPoint.setValues(newCenterValue);
            cluster.setCenter(newCenterPoint);
        }
    }

    public static void main(String[] args) {
        KMeans kmeans = new KMeans(4);
        kmeans.init();
        kmeans.clustering();
    }
}

原文:http://www.cnblogs.com/codingexperience/p/5040942.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值