一、环境:
Windows 10,Eclipse,Java JDK 1.8(需要 JDK 1.8 及以上:代码使用了 JDK 1.8 才新增的 Files.newBufferedReader(Path))
二、数据集(部分)
iris.txt(每行格式为 x,y,类别标签;程序只读取前两列):
1.4,0.2,Iris-setosa
1.4,0.2,Iris-setosa
1.3,0.2,Iris-setosa
1.5,0.2,Iris-setosa
1.4,0.2,Iris-setosa
1.7,0.4,Iris-setosa
1.4,0.3,Iris-setosa
1.5,0.2,Iris-setosa
1.4,0.2,Iris-setosa
1.5,0.1,Iris-setosa
1.5,0.2,Iris-setosa
1.6,0.2,Iris-setosa
1.4,0.1,Iris-setosa
.......
三、代码
package kmeans;
import java.io.*;
import java.nio.*;
import java.nio.file.*;
import java.util.*;
public class KMeans {

    /**
     * One 2-D point. Used both for data points read from the input file and for
     * cluster means (centres).
     */
    private static class Record {
        /** Row number for data points; cluster index for means. */
        int id;
        /** Coordinates: the first and second comma-separated column of an input line. */
        double x, y;
        /** Id of the mean this data point is currently assigned to. */
        int meanId;

        Record(int id, double x, double y) {
            this.id = id;
            this.x = x;
            this.y = y;
        }
    }

    /** Number of clusters. */
    private final int K = 3;
    /**
     * Convergence threshold: a mean counts as settled when it moved by at most this
     * (Euclidean) distance in one iteration; the algorithm stops when all K means are settled.
     */
    private final double DIS = 0.01;
    /** Hard cap on the number of update iterations, so the algorithm always terminates. */
    private static final int MAX_ITERATIONS = 300;

    /** All records read from the input file, in file order. */
    List<Record> recordList = new ArrayList<>();
    /** Total number of records. */
    int recNum;
    /** Current means; the mean at index i has id i. */
    List<Record> meansList = new ArrayList<>(K);
    /**
     * Records grouped by cluster: entry i holds the records assigned to mean i.
     * The entries are references to objects in {@link #recordList}, not copies.
     */
    List<List<Record>> kRecordList = new ArrayList<>(K);

    /**
     * Reads the data file and runs K-means on it.
     *
     * @param file path to a text file with one {@code x,y[,...]} record per line
     * @throws UncheckedIOException if the file cannot be opened or read
     * @throws IllegalArgumentException if a non-blank line has fewer than two fields
     * @throws NumberFormatException if one of the first two fields is not a number
     */
    public KMeans(String file) {
        ReadAllRecords(file);
        recNum = recordList.size();
        initKMeans();
        // NOTE(review): K_Means() is public and therefore overridable; calling it from the
        // constructor lets a subclass override run before the subclass is initialised.
        K_Means();
    }

    /**
     * Reads every record from {@code file} (decoded as UTF-8) and appends it to
     * {@link #recordList}. The first two comma-separated fields of a line become x and y;
     * any further fields (e.g. a class label) are ignored. Blank lines, such as a trailing
     * newline at the end of the file, are skipped.
     *
     * @param file location of the data file
     * @throws UncheckedIOException if the file cannot be opened or read
     * @throws IllegalArgumentException if a non-blank line has fewer than two fields
     * @throws NumberFormatException if one of the first two fields is not a number
     */
    private void ReadAllRecords(String file) {
        Path path = Paths.get(file);
        // try-with-resources guarantees the reader is closed even when parsing fails.
        try (BufferedReader br = Files.newBufferedReader(path)) {
            String line;
            int lineNo = 0;
            int id = 0;
            while ((line = br.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split(",");
                if (fields.length < 2) {
                    throw new IllegalArgumentException(
                            "line " + lineNo + " of " + file + " has fewer than 2 fields: " + line);
                }
                recordList.add(new Record(id++,
                        Double.parseDouble(fields[0]), Double.parseDouble(fields[1])));
            }
        } catch (IOException e) {
            // Fail loudly with the cause attached instead of continuing with empty/partial data.
            throw new UncheckedIOException("failed to read records from " + file, e);
        }
    }

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param mean a mean (cluster centre)
     * @param r    a record
     * @return the straight-line distance between {@code mean} and {@code r}
     */
    private double Distance(Record mean, Record r) {
        return Math.sqrt((mean.x - r.x) * (mean.x - r.x) + (mean.y - r.y) * (mean.y - r.y));
    }

    /**
     * Builds the K initial means and performs the first assignment of every record.
     * Mean i is seeded with the record at index {@code i * (recNum / K)}, i.e. evenly spaced
     * through the file. With no records this does nothing, leaving the means and clusters empty.
     */
    private void initKMeans() {
        if (recNum == 0) {
            return;
        }
        int stride = recNum / K;
        for (int i = 0; i < K; i++) {
            Record seed = recordList.get(i * stride);
            meansList.add(new Record(i, seed.x, seed.y));
        }
        for (int i = 0; i < K; i++) {
            kRecordList.add(new ArrayList<Record>());
        }
        assignClusters();
    }

    /**
     * Runs the K-means update loop.
     *
     * <p>Each iteration recomputes every mean as the centroid of its cluster. It stops as soon
     * as all K means moved by at most {@link #DIS}; otherwise it reassigns every record to its
     * nearest mean and repeats. The loop runs at most {@link #MAX_ITERATIONS} times. Does
     * nothing when no records were loaded.
     */
    public void K_Means() {
        if (meansList.isEmpty()) {
            return;
        }
        for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
            // Number of means that moved by no more than DIS in this iteration.
            int settled = 0;
            for (int i = 0; i < K; i++) {
                Record previous = meansList.get(i);
                Record updated = centroidOf(kRecordList.get(i), previous);
                if (Distance(previous, updated) <= DIS) {
                    settled++;
                }
                meansList.set(i, updated);
            }
            if (settled == K) {
                break;
            }
            assignClusters();
        }
    }

    /**
     * Returns the centroid of {@code members}, carrying over the id of {@code previous}.
     * An empty cluster keeps its previous position; dividing by a size of zero would
     * otherwise produce a NaN mean that can never win a distance comparison again.
     */
    private static Record centroidOf(List<Record> members, Record previous) {
        if (members.isEmpty()) {
            return new Record(previous.id, previous.x, previous.y);
        }
        double sumX = 0;
        double sumY = 0;
        for (Record r : members) {
            sumX += r.x;
            sumY += r.y;
        }
        return new Record(previous.id, sumX / members.size(), sumY / members.size());
    }

    /**
     * Clears every cluster and reassigns each record to its nearest mean. On a tie, the mean
     * with the lowest index wins. Updates {@code meanId} on each record.
     */
    private void assignClusters() {
        for (List<Record> cluster : kRecordList) {
            cluster.clear();
        }
        for (Record r : recordList) {
            int nearest = 0;
            double best = Distance(meansList.get(0), r);
            for (int j = 1; j < K; j++) {
                double d = Distance(meansList.get(j), r);
                if (d < best) {
                    best = d;
                    nearest = j;
                }
            }
            r.meanId = meansList.get(nearest).id;
            kRecordList.get(r.meanId).add(r);
        }
    }

    /**
     * Prints each cluster on its own line to standard output: the mean's coordinates,
     * followed by the coordinates of every member record.
     */
    public void ShowRecords() {
        for (int i = 0; i < kRecordList.size(); i++) {
            Record mean = meansList.get(i);
            StringBuilder line = new StringBuilder();
            line.append("均值点:(").append(mean.x).append(",").append(mean.y).append(") 成员:");
            for (Record record : kRecordList.get(i)) {
                line.append("(").append(record.x).append(",").append(record.y).append(") ");
            }
            System.out.println(line);
        }
    }

    public static void main(String[] args) {
        KMeans kmeans = new KMeans("./files/iris.txt");
        kmeans.ShowRecords();
    }
}