简单介绍
K-means clustering属于无监督学习(unsupervised learning)的范畴,由于设计思想易于理解,并且计算相对简单,因此实现起来较为简单,本文将采用PyTorch进行实现.
源代码以及题目等文件见博客
原理
K-means clustering主要分为两步:找到各类的中心点(记为 c e n t r o i d s 1 , c e n t r o i d s 2 , . . . , c e n t r o i d s k centroids_1,centroids_2,...,centroids_k centroids1,centroids2,...,centroidsk)和遍历所有的数据,根据其与各中心点的距离为其分类.其中K为最终的分类数,为超参数,需要使用者自行选择.
计算中心点
计算中心点的过程即是求平均的过程,假设一个类中存在n个点,那么中心点即是n个点的坐标平均值.这样做是有数学依据的,以特征值只有两个时为例,一般情况下,我们希望找到的中心点具有:中心点到各个点的几何距离和最小的特性,因此损失函数为:
l = ( x 11 − c 1 ) 2 + ( x 12 − c 2 ) 2 + ( x 21 − c 1 ) 2 + ( x 22 − c 2 ) 2 l=(x_{11}-c_1)^2+(x_{12}-c_2)^2 + (x_{21}-c_1)^2+(x_{22}-c_2)^2 l=(x11−c1)2+(x12−c2)2+(x21−c