运行类
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
/**
 * Command-line entry point that runs k-means clustering over a CSV file.
 *
 * <p>Usage: {@code java KmeansRun [dataPath] [k]}. Both arguments are optional;
 * when omitted, the historical defaults (an empty path and {@code k = 3}) are used.
 */
public class KmeansRun {

    /** Upper bound on the number of refinement passes before giving up on convergence. */
    private static final int MAX_ITERATIONS = 1000;

    /**
     * Reads the data, picks initial centers, then alternates between assigning points
     * and recomputing centers until the centers stop moving or the iteration cap is hit.
     *
     * @param args optional: {@code args[0]} is the CSV path, {@code args[1]} is the cluster count
     * @throws NumberFormatException if {@code args[1]} is present but not an integer
     */
    public static void main(String[] args) throws IOException, URISyntaxException, InterruptedException {
        long start = System.currentTimeMillis(); // start time

        // Allow the input and the cluster count to be supplied on the command line;
        // fall back to the original hard-coded values otherwise.
        String dataPath = args.length > 0 ? args[0] : "";
        int k = args.length > 1 ? Integer.parseInt(args[1]) : 3;

        // Load the data set.
        ArrayList<ArrayList<Double>> data = read.readCSV(dataPath);
        if (data.isEmpty()) {
            // Without rows there is nothing to cluster, and picking centers would fail.
            System.err.println("No data rows could be read from '" + dataPath
                    + "'. Usage: java KmeansRun <csv-path> [k]");
            return;
        }

        // Initial cluster centers.
        ArrayList<ArrayList<Double>> center = CenterRandom.centerRandomChoice(data, k);

        // Recompute centers until they no longer change.
        for (int i = 0; i < MAX_ITERATIONS; i++) {
            System.out.println("--------------------------"+i+"----------------------------------");

            // Assign every point to its nearest current center.
            ArrayList<DataWithIndex.dataWithIndex> dataindex = DistributionData.distributData(data, center);

            // Compute the new centers from the assignment.
            ArrayList<ArrayList<Double>> newcenter = CalCenters.calCenters(dataindex, k);

            // calCenters yields an empty list for a cluster that received no points.
            // Using that as a center would break the next assignment pass, so keep
            // the previous center for such clusters.
            for (int c = 0; c < newcenter.size() && c < center.size(); c++) {
                if (newcenter.get(c).isEmpty()) {
                    newcenter.set(c, center.get(c));
                }
            }

            // Total movement between the new and the old centers.
            double dis = CalUtil.calDistanceBetweenCenters(newcenter, center);
            System.out.println("------------------------dis----------------------------");
            System.out.println(dis);

            // Zero movement means the clustering has converged.
            if (dis == 0) {
                break;
            }

            // The new centers become the old centers of the next pass.
            center = newcenter;
        }

        // Final clustering result.
        ArrayList<DataWithIndex.dataWithIndex> dataindex = DistributionData.distributData(data, center);
        System.out.println(dataindex);

        long end = System.currentTimeMillis(); // end time
        System.out.println("共耗时"+(end-start)+"毫秒"); // elapsed time
    }
}
创建带有簇标签的数据存储方式
import java.util.ArrayList;
/**
 * Holder for the {@link dataWithIndex} type: a data point paired with the
 * label of the cluster it has been assigned to.
 */
public class DataWithIndex {

    /** A single data point together with its cluster label. */
    static class dataWithIndex {
        ArrayList<Double> data; // coordinates of the point
        int index;              // cluster label

        /** @return the coordinates of the point */
        public ArrayList<Double> getData() {
            return data;
        }

        /** @return the cluster label */
        public int getIndex() {
            return index;
        }

        /** @param data the coordinates of the point */
        public void setData(ArrayList<Double> data) {
            this.data = data;
        }

        /** @param index the cluster label */
        public void setIndex(int index) {
            this.index = index;
        }

        /**
         * Renders the label and coordinates so that printing a list of points
         * shows the clustering result rather than object identity hashes.
         */
        @Override
        public String toString() {
            return "dataWithIndex{index=" + index + ", data=" + data + "}";
        }
    }

    public DataWithIndex() {
    }
}
选取初始的随机中心类
import java.util.ArrayList;
import java.util.Random;
/** Chooses the initial cluster centers for k-means. */
public class CenterRandom {

    /**
     * Picks {@code k} rows of {@code data} at random, without choosing the same
     * row index twice, to serve as the initial cluster centers.
     *
     * <p>Each returned center is a copy of the chosen row, so later changes to a
     * center cannot alter the data set.
     *
     * @param data the data set, one row per point
     * @param k    the number of centers to choose; a value of zero or less yields an empty list
     * @return a list of {@code k} centers
     * @throws IllegalArgumentException if {@code k} exceeds the number of rows
     */
    public static ArrayList<ArrayList<Double>> centerRandomChoice(ArrayList<ArrayList<Double>> data, int k) {
        ArrayList<ArrayList<Double>> center = new ArrayList<>();
        if (k <= 0) {
            return center;
        }

        int elementsSize = data.size();
        if (k > elementsSize) {
            throw new IllegalArgumentException(
                    "Cannot choose " + k + " distinct centers from " + elementsSize + " data rows");
        }

        // Partial Fisher-Yates shuffle over the row indices: after step i the first
        // i + 1 slots hold distinct, uniformly chosen indices.
        int[] indices = new int[elementsSize];
        for (int i = 0; i < elementsSize; i++) {
            indices[i] = i;
        }

        Random random = new Random();
        for (int i = 0; i < k; i++) {
            int j = i + random.nextInt(elementsSize - i);
            int chosen = indices[j];
            indices[j] = indices[i];
            indices[i] = chosen;
            center.add(new ArrayList<>(data.get(chosen)));
        }
        return center;
    }
}
计算新的簇中心类
import java.util.ArrayList;
/** Recomputes cluster centers from labelled data points. */
public class CalCenters {

    /**
     * Computes the center of each cluster label {@code 0 .. k-1} as the
     * coordinate-wise mean of the points carrying that label.
     *
     * <p>A label that no point carries produces an empty list at its position.
     *
     * @param dataindex the points together with their cluster labels
     * @param k         the number of cluster labels to process
     * @return one center per label, in label order
     */
    public static ArrayList<ArrayList<Double>> calCenters(ArrayList<DataWithIndex.dataWithIndex> dataindex, int k) {
        ArrayList<ArrayList<Double>> result = new ArrayList<>();
        int label = 0;
        while (label < k) {
            ArrayList<Double> sums = new ArrayList<>(); // running coordinate sums for this label
            int members = 0;                            // number of points carrying this label

            for (DataWithIndex.dataWithIndex point : dataindex) {
                if (point.getIndex() != label) {
                    continue;
                }
                if (sums.isEmpty()) {
                    // First member: seed the sums with a copy of its coordinates.
                    sums.addAll(point.getData());
                } else {
                    // Later members: accumulate coordinate by coordinate.
                    sums = CalUtil.addElement(sums, point.getData());
                }
                members++;
            }

            // Turn the sums into means.
            int dims = sums.size();
            for (int dim = 0; dim < dims; dim++) {
                sums.set(dim, sums.get(dim) / members);
            }
            result.add(sums);
            label++;
        }
        return result;
    }
}
计算新的簇标签
import java.util.ArrayList;
/** Assigns data points to their nearest cluster center. */
public class DistributionData {

    /**
     * Labels every point with the position of the center closest to it, using
     * {@code CalUtil.calDistance} as the distance measure.
     *
     * <p>Ties go to the center with the lowest position. A point for which no
     * distance compares as smaller than positive infinity is labelled {@code 0}.
     *
     * @param data   the data set, one row per point
     * @param center the current cluster centers
     * @return the points paired with their cluster labels, in the order of {@code data}
     */
    public static ArrayList<DataWithIndex.dataWithIndex> distributData(ArrayList<ArrayList<Double>> data, ArrayList<ArrayList<Double>> center) {
        ArrayList<DataWithIndex.dataWithIndex> dataindex = new ArrayList<>(data.size());
        for (ArrayList<Double> point : data) {
            // Reset per point, and start from infinity rather than a fixed cutoff,
            // so that a far-away point still gets its true nearest center instead
            // of inheriting the label of the previous point.
            int index = 0;
            double min = Double.POSITIVE_INFINITY;
            for (int j = 0; j < center.size(); j++) {
                double dis = CalUtil.calDistance(point, center.get(j));
                if (dis < min) {
                    min = dis;
                    index = j;
                }
            }

            DataWithIndex.dataWithIndex onedataindex = new DataWithIndex.dataWithIndex();
            onedataindex.setData(point);
            onedataindex.setIndex(index);
            dataindex.add(onedataindex);
        }
        return dataindex;
    }
}
各种数值计算类
import java.util.ArrayList;
/** Numeric helpers for k-means: vector distance, vector addition and center movement. */
public class CalUtil {

    /**
     * Euclidean distance between two vectors.
     *
     * <p>Only the first {@code element1.size()} coordinates are compared.
     *
     * @param element1 the first vector; its length determines how many coordinates are used
     * @param element2 the second vector
     * @return the square root of the summed squared coordinate differences
     */
    public static double calDistance(ArrayList<Double> element1, ArrayList<Double> element2) {
        double squaredSum = 0;
        final int dims = element1.size();
        int idx = 0;
        while (idx < dims) {
            double delta = element1.get(idx) - element2.get(idx);
            squaredSum += delta * delta;
            idx++;
        }
        return Math.sqrt(squaredSum);
    }

    /**
     * Adds {@code element2} onto {@code element1} coordinate by coordinate.
     *
     * <p>{@code element1} is modified in place and is also the returned object.
     *
     * @param element1 the vector that receives the sums
     * @param element2 the vector to add
     * @return {@code element1}, now holding the sums
     */
    public static ArrayList<Double> addElement(ArrayList<Double> element1, ArrayList<Double> element2) {
        final int dims = element1.size();
        for (int idx = 0; idx < dims; idx++) {
            double total = element1.get(idx) + element2.get(idx);
            element1.set(idx, total);
        }
        return element1;
    }

    /**
     * Measures how far a set of centers moved between two iterations, as the sum
     * of the distances between centers at the same position in both lists.
     *
     * <p>Centers are compared position by position; the original author's note is
     * that the data is read in the same order each time, so matching positions
     * refer to the same cluster. A result of zero means no center moved.
     *
     * @param oldCenter the first list of centers
     * @param newCenter the second list of centers
     * @return the summed distance, or {@code 1000} when {@code oldCenter} has more
     *         entries than {@code newCenter}
     */
    public static double calDistanceBetweenCenters(ArrayList<ArrayList<Double>> oldCenter, ArrayList<ArrayList<Double>> newCenter) {
        if (newCenter.size() < oldCenter.size()) {
            return 1000;
        }
        double total = 0;
        for (int pos = 0; pos < oldCenter.size(); pos++) {
            total += calDistance(oldCenter.get(pos), newCenter.get(pos));
        }
        return total;
    }
}
文件读取类
import java.io.*;
import java.util.ArrayList;
/** Loads numeric data from CSV files. */
public class read {

    /**
     * Reads a comma-separated file of numbers into one list per line.
     *
     * <p>The first line is treated as a header and skipped. Blank lines are
     * ignored. If the file cannot be read or a field is not a number, the stack
     * trace is printed and the rows parsed so far are returned.
     *
     * @param readFile path of the CSV file
     * @return the parsed rows, in file order; empty when nothing could be read
     */
    public static ArrayList<ArrayList<Double>> readCSV(String readFile) {
        ArrayList<ArrayList<Double>> DataS = new ArrayList<>();
        // try-with-resources closes the reader on every path, including failures.
        try (BufferedReader reader = new BufferedReader(new FileReader(readFile))) {
            reader.readLine(); // header line, not data

            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // e.g. a trailing newline at the end of the file
                }
                // A fresh list per line: sharing one list would make every row
                // an alias of the same, ever-growing list.
                ArrayList<Double> row = new ArrayList<>();
                for (String s : line.split(",")) {
                    row.add(Double.parseDouble(s));
                }
                DataS.add(row);
            }
        } catch (IOException | RuntimeException e) {
            // Best effort, as before: report the problem and return what was parsed.
            e.printStackTrace();
        }
        return DataS;
    }
}