与分类、序列标注等任务不同,聚类是在事先并不知道任何样本标签的情况下,通过数据之间的内在关系把样本划分为若干类别,使得同类别样本之间的相似度高,不同类别之间的样本相似度低(即增大类内聚,减少类间距)。
聚类属于非监督学习,K均值聚类是最基础常用的聚类算法。它的基本思想是,通过迭代寻找K个簇(Cluster)的一种划分方案,使得聚类结果对应的损失函数最小。
通俗来说,聚类就是把一些点(以二维坐标系为例)选取适当的中心将这些点进行分类。
那么我们如何使用Java来实现该算法的实现,这里以一些随机生成的二维坐标轴点为例。
首先,我们在IDEA下新建软件包,KMeansTest,然后新建类KMeans和类KMeansMethod以及文本文件result用于存放生成的点以及对应的聚簇中心点。
在KMeans类中编写程序的主函数和主运行逻辑,KMeansMethod中具体实现主函数中需要用的方法。
首先,我们需要获取对应的随机生成的点集
编写生成点集的函数前,我们需要编写代码用于生成随机点,这里我们为了方便,将每个点的x与y值设置为不大于100的正整数。
生成随机点的函数getInts()代码如下
// 生成随机点
private static int[] getInts(){
int[] ints = new int[2];
Random random = new Random();
ints[0] = random.nextInt(100);
ints[1] = random.nextInt(100);
return ints;
}
生成随机点集的函数getPoints(int n)的代码如下,n为生成点的数量
// 获取点的集合
public static int[][] getPoints(int n) {
int[][] points = new int[n][];
for (int i = 0; i < n; i++) {
points[i] = getInts();
}
return points;
}
至此,我们已经成功获得需要查找聚簇中心的点集,现在我们需要获取这些点的聚簇中心,编写函数getRandomPoints(int k) k为所需聚簇中心的数量,代码如下
// 生成随机聚簇中心
public static int[][] getRandomPoints(int k) {
int[][] ints = new int[k][];
for (int i = 0; i < k; i++) {
ints[i] = getInts();
}
return ints;
}
已经获取点集,并且我们已经获取到了随机生成的聚簇中心点集,下面我们需要将随机生成的点分配到每一个聚簇中心上,我们需要一个计算距离的函数和一个输出被分配聚簇中心点的函数。
计算欧几里得距离的函数getDistance(int[] n1, int[] n2)如下,n1和n2分别是需要求解距离的点。
// 计算欧几里得距离
public static double getDistance(int[] n1, int[] n2) {
return sqrt(pow((n1[0] - n2[0]), 2) + pow((n1[1] - n2[1]), 2));
}
返回适配聚簇中心的函数getClusterCenter(int[][] clusterPoints, int[] point)如下,其中clusterPoints为聚簇中心点的点集,point为该点
// 返回某一个点最近的聚类中心
public static int[] getClusterCenter(int[][] clusterPoints, int[] point){
double minValue = Integer.MIN_VALUE;
int minIndex = 0;
for (int i = 0; i < clusterPoints.length; i++) {
if (getDistance(clusterPoints[i], point) > minValue) {
minValue = getDistance(clusterPoints[i], point);
minIndex = i;
}
}
return clusterPoints[minIndex];
}
此时,我们已经成功编写KMeansMethod类的所有数学方法,接下来我们需要在KMeans类中编写该项目的主要业务逻辑。
该项目的主要业务逻辑为,用户指定测试随机点的数量,并且指定聚簇中心数量,然后由程序将所有生成的随机点及其对应的聚簇中心写入文件result,并且在控制台输出聚簇中心。
KMeans.java中的代码如下
package TestExample.KMeansTest;
import java.io.*;
import java.util.*;
import static TestExample.KMeansTest.KMeansMethod.*;
/**
* @author loop
* @version 1.0
*/
public class KMeans {
public static void main(String[] args) throws IOException {
Scanner sc = new Scanner(System.in);
System.out.println("请输入模拟的点数");
int[][] points = getPoints(sc.nextInt());
System.out.println("选取K个样本作为聚簇核心,K <= " + points.length);
int k = sc.nextInt();
if (k > points.length){
System.out.println("输入错误");
}else {
int[][] clusterPoints = getRandomPoints(k);
FileWriter writer = new FileWriter("src/TestExample/KMeansTest/result");
writer.write("原始点 聚簇点\n");
Set<int[]> clusters = new HashSet<>();
for (int[] i : points) {
int[] cluster = getClusterCenter(clusterPoints, i);
clusters.add(cluster);
writer.write(Arrays.toString(i) + " " + Arrays.toString(cluster) + "\n");
}
System.out.println("聚簇中心分别为:");
for (int[] i : clusters) {
System.out.println(Arrays.toString(i));
}
writer.flush();
writer.close();
}
}
}
KMeansMethod.java的代码如下,其中被注释的是优化前的部分
package TestExample.KMeansTest;
import java.util.*;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
/**
* @author loop
* @version 1.0
*
*/
public class KMeansMethod {
// 返回某一个点最近的聚类中心
public static int[] getClusterCenter(int[][] clusterPoints, int[] point){
double minValue = Integer.MIN_VALUE;
int minIndex = 0;
for (int i = 0; i < clusterPoints.length; i++) {
if (getDistance(clusterPoints[i], point) > minValue) {
minValue = getDistance(clusterPoints[i], point);
minIndex = i;
}
}
return clusterPoints[minIndex];
}
// 生成随机聚簇中心
public static int[][] getRandomPoints(int k) {
int[][] ints = new int[k][];
for (int i = 0; i < k; i++) {
ints[i] = getInts();
}
return ints;
}
// public static int[][] getRandomPoints(int[][] points, int k) {
// int[][] randomPoint = new int[k][];
// ArrayList<Integer> list = getDistinct(points.length, k);
// for (int i = 0; i < k; i++) {
// randomPoint[i] = points[list.get(i)];
// }
// return randomPoint;
// }
//
// // 生成不重复随机数
// private static ArrayList<Integer> getDistinct(int len, int k) {
// ArrayList<Integer> list = new ArrayList<>();
// Random random = new Random();
// for (int i = 0; i < k; i++) {
// int order = random.nextInt(len);
// if (!list.contains(order)){
// list.add(order);
// }else {
// i --;
// }
// }
// return list;
// }
// 计算欧几里得距离
public static double getDistance(int[] n1, int[] n2) {
return sqrt(pow((n1[0] - n2[0]), 2) + pow((n1[1] - n2[1]), 2));
}
// 获取点的集合
public static int[][] getPoints(int n) {
int[][] points = new int[n][];
for (int i = 0; i < n; i++) {
points[i] = getInts();
}
return points;
}
// 生成随机点
private static int[] getInts(){
int[] ints = new int[2];
Random random = new Random();
ints[0] = random.nextInt(100);
ints[1] = random.nextInt(100);
return ints;
}
}
对此程序进行测试
发现控制台输出正常,同时我们查看result文件,发现点集以及聚簇中心成功写入
则该程序成功运行
注:本次测试所采用的点为随机生成,具体操作时需要按照自己的需求修改代码。