最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。
前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。
一,什么是KMeans聚类算法??
非常传统的聚类算法,目的是将一堆数据进行分类。
它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?
举个例子:
a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。
so easy 是不是
二,使用的手写数字测试集??
我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。
但是一定有人会问到,mnist测试集应该怎么通过java使用呢?
不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。
具体效果像这样:
这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。
我将txt命名为:数字名-标号的形式,方便之后训练和测试。
三,java手撕KMeans算法
先摆上一个算法流程图
1.首先定义:
训练图片(50000 * 28 * 28 的三维数组)
聚类中心(10 * 28 * 28的三维数组)
每张图片到聚类中心的距离(50000 * 10 的二维数组)
旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)
static float[][][] num = new float[50000][28][28];
static float[][][] center = new float[10][28][28];// 聚类中心
static long[][] distance = new long[num.length][10];
static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
static ArrayList<Integer>[] newKinds = new ArrayList[10];
2.定义方法:
从Txt文件导入测试数据的方法
public static void getTXT(String path,int img,int x,int y) throws IOException {
File file = new File(path);
FileInputStream fis = new FileInputStream(file);
InputStreamReader isr = new InputStreamReader(fis);
BufferedReader br = new BufferedReader(isr);
String line;
while((line = br.readLine()) != null){
boolean isNum = false;
for(int i = 0;i < line.length();i ++){
if(line.charAt(i) != ' ' && !isNum){
// 如果遇到数字
isNum = true;
float tempNum = 0;
// 取数字
while(i < line.length() && line.charAt(i) != ' '){
tempNum = tempNum * 10 + line.charAt(i) - '0';
i++;
}
isNum = false;
if(y < 28){
}
else{
y = 0;
x ++;
}
num[img][x][y] = tempNum;
y++;
}
}
}
br.close();
}
获得图片到聚类中心距离的方法