//第一种方式:使用PPoint类中的flag属性记录每个点所属的聚类中心;第二种方式(见下文)直接使用集合存储聚类中心及该中心的点集
import java.util.Random;
import java.util.Scanner;
class PPoint{
public float x;
public float y;
public int flag = -1;
public PPoint(){
}
public PPoint(float x,float y){
this.x = x;
this.y = y;
}
}
/**
 * K-means clustering on random points of a 10x10 integer grid.
 *
 * <p>Each point's cluster membership is stored in its {@code flag} field as a 1-based index into
 * {@link #pcore}; cluster centers are marked with {@code flag == 0}.
 */
public class Test {
    /** Random coordinates are drawn from {@code [0, GRID_SIZE)} on each axis. */
    private static final int GRID_SIZE = 10;

    /** Iteration stops once every center moves by at most this distance. */
    private static final double CONVERGENCE_THRESHOLD = 0.01;

    /** One shared generator for all random draws. */
    private final Random random = new Random();

    /** All generated points. */
    PPoint pc[] = null;

    /** Current cluster centers. */
    PPoint pcore[] = null;

    /** Centers recomputed from the latest assignment, parallel to {@link #pcore}. */
    PPoint pcoren[] = null;

    /**
     * Reads the point count and the center count from stdin, generates that many distinct random
     * grid points, picks distinct points as the initial centers, and prints both.
     *
     * @throws IllegalArgumentException if the point count is outside {@code [1, 100]} or the center
     *     count is outside {@code [1, pointCount]}; for such values the duplicate-rejection
     *     sampling could never finish
     */
    public void init() {
        Scanner sc = new Scanner(System.in);
        System.out.println("请输入生成随机点个数");
        int num = sc.nextInt();
        int maxPoints = GRID_SIZE * GRID_SIZE;
        if (num < 1 || num > maxPoints) {
            throw new IllegalArgumentException(
                    "point count must be in [1, " + maxPoints + "], got " + num);
        }
        pc = new PPoint[num];
        // Rejection sampling: redraw until the point is not already present.
        int filled = 0;
        while (filled < num) {
            float x = random.nextInt(GRID_SIZE);
            float y = random.nextInt(GRID_SIZE);
            if (!containsPoint(pc, filled, x, y)) {
                pc[filled++] = new PPoint(x, y);
            }
        }

        System.out.println("请输入聚类中心个数");
        int core = sc.nextInt();
        if (core < 1 || core > num) {
            throw new IllegalArgumentException(
                    "center count must be in [1, " + num + "], got " + core);
        }
        pcore = new PPoint[core];
        pcoren = new PPoint[core];
        // Pick distinct point indices so no two initial centers coincide.
        int[] chosen = new int[core];
        int picked = 0;
        while (picked < core) {
            int idx = random.nextInt(num);
            if (!containsIndex(chosen, picked, idx)) {
                chosen[picked] = idx;
                pcore[picked] = new PPoint(pc[idx].x, pc[idx].y);
                pcore[picked].flag = 0; // 0 marks a cluster center
                picked++;
            }
        }

        System.out.println("生成随机点如下:");
        for (int i = 0; i < num; i++) {
            System.out.println(pc[i].x + "," + pc[i].y);
        }
        System.out.println("生成聚类中心如下");
        for (int i = 0; i < pcore.length; i++) {
            System.out.println("<" + pcore[i].x + "," + pcore[i].y + ">");
        }
    }

    /** Returns true if one of the first {@code count} entries of {@code pts} is at (x, y). */
    private static boolean containsPoint(PPoint[] pts, int count, float x, float y) {
        for (int j = 0; j < count; j++) {
            if (pts[j].x == x && pts[j].y == y) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code value} occurs among the first {@code count} entries of {@code arr}. */
    private static boolean containsIndex(int[] arr, int count, int value) {
        for (int j = 0; j < count; j++) {
            if (arr[j] == value) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs k-means iterations (assign, then recompute centers) until no center moves farther than
     * {@link #CONVERGENCE_THRESHOLD}, then prints a completion message.
     *
     * <p>Implemented as a loop rather than recursion so long runs cannot overflow the stack.
     */
    public void moveCore() {
        while (true) {
            searchBelong();
            calAverage();
            if (!anyCoreMoved()) {
                System.out.println("迭代完毕");
                return;
            }
            copyCore(pcore, pcoren);
        }
    }

    /** Returns true if some recomputed center moved more than the convergence threshold. */
    private boolean anyCoreMoved() {
        for (int i = 0; i < pcore.length; i++) {
            if (distPPoint(pcore[i], pcoren[i]) > CONVERGENCE_THRESHOLD) {
                return true;
            }
        }
        return false;
    }

    /**
     * Copies coordinates from {@code newCore} into the existing points of {@code oldCore} and marks
     * them as centers ({@code flag = 0}).
     *
     * @param oldCore destination centers (mutated in place)
     * @param newCore source centers; must be at least as long as {@code oldCore}
     */
    public void copyCore(PPoint[] oldCore, PPoint[] newCore) {
        for (int i = 0; i < oldCore.length; i++) {
            oldCore[i].x = newCore[i].x;
            oldCore[i].y = newCore[i].y;
            oldCore[i].flag = 0;
        }
    }

    /**
     * Assigns every point to its nearest center, storing the 1-based center index in
     * {@code flag}. Ties go to the lower-indexed center.
     */
    public void searchBelong() {
        for (int i = 0; i < pc.length; i++) {
            double best = Double.MAX_VALUE;
            int label = -1;
            for (int j = 0; j < pcore.length; j++) {
                double distance = distPPoint(pc[i], pcore[j]);
                if (distance < best) {
                    best = distance;
                    label = j;
                }
            }
            pc[i].flag = label + 1;
        }
    }

    /**
     * Returns the Euclidean distance between two points.
     *
     * @param i first point
     * @param j second point
     * @return the straight-line distance between {@code i} and {@code j}
     */
    public double distPPoint(PPoint i, PPoint j) {
        return Math.sqrt(Math.pow(i.x - j.x, 2) + Math.pow(i.y - j.y, 2));
    }

    /**
     * Prints the members of each cluster and stores each cluster's mean position in
     * {@link #pcoren}. A cluster with no members keeps its current center; dividing by zero would
     * produce a NaN center that could never attract points again.
     */
    public void calAverage() {
        for (int i = 0; i < pcore.length; i++) {
            System.out.println("属于<" + pcore[i].x + "," + pcore[i].y + ">的点有:");
            float sumX = 0;
            float sumY = 0;
            int number = 0;
            for (int j = 0; j < pc.length; j++) {
                if (pc[j].flag == (i + 1)) {
                    System.out.println(pc[j].x + "," + pc[j].y);
                    sumX += pc[j].x;
                    sumY += pc[j].y;
                    number++;
                }
            }
            pcoren[i] = new PPoint();
            if (number == 0) {
                pcoren[i].x = pcore[i].x;
                pcoren[i].y = pcore[i].y;
            } else {
                pcoren[i].x = sumX / number;
                pcoren[i].y = sumY / number;
            }
            pcoren[i].flag = 0;
            System.out.println("新的聚类中心为<" + pcoren[i].x + "," + pcoren[i].y + ">");
        }
    }

    /** Entry point: generate the data interactively, then cluster it. */
    public static void main(String[] args) {
        Test test = new Test();
        test.init();
        test.moveCore();
    }
}
//这种方式使用集合存储聚类中心和点集,从文件中读取点集
package test;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import org.apache.commons.io.IOUtils;
/**
 * K-means clustering over points read from {@code src/k-means_test.txt}.
 *
 * <p>Cluster membership is kept in maps rather than in a per-point flag.
 */
public class Test {
    /** Convergence threshold: iteration stops once no center moves farther than this. */
    private static final double maxDistance = 1.0e-9;

    /** Every point read from the input file. */
    private static List<Point> allPoint = new ArrayList<>();

    /** Center -> points assigned to it during the current iteration. */
    private static Map<Point, List<Point>> map = new HashMap<>();

    /** Current center (key) -> the center it replaced in the previous iteration (value). */
    private static Map<Point, Point> replaceOldPoint = new HashMap<>();

    /**
     * Reads the cluster count from stdin, loads the points, and runs k-means until the centers
     * stop moving. Then prints the elapsed time in milliseconds.
     *
     * @throws IOException if the point file cannot be read or contains a malformed point
     */
    public static void main(String[] args) throws IOException {
        Scanner sc = new Scanner(System.in);
        System.out.println("请输入聚类中心数目");
        int n = sc.nextInt();
        long startTime = System.currentTimeMillis();
        readAllPoint();
        randomClusterCenter(n);
        for (int i = 0; ; i++) {
            assignPoints();
            printResult(i + 1);
            updateCenters();
            map.clear();
            if (maxMoveDistance() < maxDistance) {
                break;
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println("耗时:" + (endTime - startTime));
    }

    /** Groups every point under its nearest current center in {@link #map}. */
    private static void assignPoints() {
        for (Point p : allPoint) {
            Point center = judge(p);
            List<Point> members = map.get(center);
            if (members == null) {
                members = new ArrayList<>();
                map.put(center, members);
            }
            members.add(p);
        }
    }

    /**
     * Replaces each center with the mean of its assigned points and records the old center as
     * the map value so {@link #maxMoveDistance()} can measure the move.
     *
     * <p>A center with no assigned points is kept in place (zero movement). Dropping it would
     * silently reduce the number of clusters.
     */
    private static void updateCenters() {
        Map<Point, Point> next = new HashMap<>();
        for (Point center : replaceOldPoint.keySet()) {
            List<Point> members = map.get(center);
            if (members == null || members.isEmpty()) {
                next.put(center, center);
                continue;
            }
            double totalX = 0;
            double totalY = 0;
            for (Point pp : members) {
                totalX += pp.getX();
                totalY += pp.getY();
            }
            next.put(new Point(totalX / members.size(), totalY / members.size()), center);
        }
        replaceOldPoint.clear();
        replaceOldPoint.putAll(next);
    }

    /**
     * Appends every point in {@code src/k-means_test.txt} to {@link #allPoint}.
     *
     * <p>The file holds whitespace-separated tokens of the form {@code (x,y)}. Runs of spaces and
     * line breaks between tokens are tolerated.
     *
     * @throws IOException if the file cannot be read or a token lacks an x or y coordinate
     */
    public static void readAllPoint() throws IOException {
        String content;
        // try-with-resources closes the reader even if reading fails.
        try (FileReader reader = new FileReader(new File("src/k-means_test.txt"))) {
            // commons-io reads the whole file into one string.
            content = IOUtils.toString(reader);
        }
        for (String token : content.trim().split("\\s+")) {
            if (token.isEmpty()) {
                continue; // only happens for a blank file
            }
            // Strip the parentheses, leaving "x,y".
            String line = token.replaceAll("[\\(\\)]+", "");
            String[] fields = line.split(",");
            if (fields.length < 2) {
                throw new IOException("malformed point token: " + token);
            }
            Point point = new Point();
            point.setX(Double.parseDouble(fields[0]));
            point.setY(Double.parseDouble(fields[1]));
            allPoint.add(point);
        }
    }

    /**
     * Picks {@code n} distinct points from {@link #allPoint} as the initial centers. Each one is
     * paired with a far-away sentinel "previous center" so the first move is never treated as
     * converged.
     *
     * @param n number of clusters
     * @throws IllegalArgumentException if {@code n < 1} or {@code n} exceeds the number of points;
     *     for such values no valid distinct selection exists
     */
    public static void randomClusterCenter(int n) {
        if (n < 1 || n > allPoint.size()) {
            throw new IllegalArgumentException(
                    "cluster count must be in [1, " + allPoint.size() + "], got " + n);
        }
        Random random = new Random();
        List<Integer> chosen = new ArrayList<>();
        while (chosen.size() < n) {
            int idx = random.nextInt(allPoint.size());
            if (!chosen.contains(idx)) {
                chosen.add(idx);
            }
        }
        for (int number : chosen) {
            Point old = new Point(Double.MAX_VALUE, Double.MAX_VALUE);
            replaceOldPoint.put(allPoint.get(number), old);
        }
    }

    /**
     * Returns the current center nearest to {@code p}, or {@code null} if there are no centers.
     *
     * @param p the point to classify
     * @return the closest key of {@link #replaceOldPoint}
     */
    public static Point judge(Point p) {
        double best = Double.MAX_VALUE;
        Point nearest = null;
        for (Point center : replaceOldPoint.keySet()) {
            double distance = getDistance(p, center);
            if (distance < best) {
                best = distance;
                nearest = center;
            }
        }
        return nearest;
    }

    /**
     * Returns the Euclidean distance between two points.
     *
     * @param p first point
     * @param pp second point
     * @return the straight-line distance between {@code p} and {@code pp}
     */
    public static double getDistance(Point p, Point pp) {
        return Math.sqrt(Math.pow(p.getX() - pp.getX(), 2) + Math.pow(p.getY() - pp.getY(), 2));
    }

    /**
     * Prints the centers and the points assigned to each for iteration {@code i}.
     *
     * @param i 1-based iteration number
     */
    public static void printResult(int i) {
        System.out.println("第" + i + "次聚类结果");
        String s = "聚类中心为:" + replaceOldPoint.keySet();
        System.out.println(s);
        for (Point p : replaceOldPoint.keySet()) {
            System.out.println("属于中心<" + p.getX() + "," + p.getY() + ">的点:" + map.get(p));
        }
    }

    /**
     * Returns the largest distance any center moved in the last update, or 0 if there are no
     * centers. Distances are non-negative, so 0 is the correct starting value.
     * {@code Double.MIN_VALUE} is the smallest positive double, not the most negative one.
     */
    public static double maxMoveDistance() {
        double max = 0;
        for (Map.Entry<Point, Point> e : replaceOldPoint.entrySet()) {
            max = Math.max(max, getDistance(e.getKey(), e.getValue()));
        }
        return max;
    }
}
//点集文件:src/k-means_test.txt(原文此处的下载链接已失效)
//文件内容为以空格分隔的坐标点,每个点形如 (x,y),例如:(1.0,2.5) (3.2,4.0)