机器学习:k-means算法

K-means是最经典的基于划分的聚类方法,通过迭代更新聚类中心找到最佳聚类。其优点是简单快速,但对初始中心敏感且需预设聚类数。适用于已知聚类数目的场景。
摘要由CSDN通过智能技术生成

描述

算法接受参数 k ;然后将事先输入的n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。

K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
该算法的最大优势在于简洁和快速。算法的关键在于初始中心的选择和距离公式。

聚类

聚类分析是一种静态数据分析方法,常被用于机器学习,模式识别,数据挖掘等领域。通常认为,聚类是一种无监督式的机器学习方法,它的过程是这样的:在未知样本类别的情况下,通过计算样本彼此间的距离(欧式距离,马式距离,汉明距离,余弦距离等)来估计样本所属类别。从结构性来划分,聚类方法分为自上而下和自下而上两种方法,前者的算法是先把所有样本视为一类,然后不断从这个大类中分离出小类,直到不能再分为止;后者则相反,首先所有样本自成一类,然后不断两两合并,直到最终形成几个大类。 

聚类算法

常用的聚类方法主要有以下四种:

  • Connectivity based clustering  (如hierarchical clustering 层次聚类法)
  • Centroid-based clustering  (如kmeans)
  • Distribution-based clustering
  • Density-based clustering

优缺点

Kmeans聚类是一种自下而上的聚类方法,它的优点是简单、速度快;缺点是聚类结果与初始中心的选择有关系,且必须提供聚类的数目。Kmeans的第二个缺点是致命的,因为在有些时候,我们不知道样本集将要聚成多少个类别,这种时候kmeans是不适合的,推荐使用hierarchical 或meanshift来聚类。第一个缺点可以通过多次聚类取最佳结果来解决。

算法流程

首先从n个数据对象任意选择 k 个对象作为初始聚类中心;

而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;

然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);

不断重复这一过程直到标准测度函数开始收敛为止。

一般都采用均方差作为标准测度函数.

k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

输入:k, data[n];

  1. 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];
  2. 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;
  3. 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;
  4. 重复(2)(3),直到所有c[i]值的变化小于给定阈值。

源码示例

python

自编

# coding=utf-8

'''
 * k-Means算法,聚类算法;  
 * 实现步骤:1. 首先是随机获取总体中的K个元素作为总体的K个中心;  
 * 2. 接下来对总体中的元素进行分类,每个元素都去判断自己到K个中心的距离,并归类到最近距离中心去;  
 * 3. 计算每个聚类的平均值,并作为新的中心点  
 * 4. 重复2,3步骤,直到这k个中线点不再变化(收敛了),或执行了足够多的迭代  
'''

from numpy import linalg as ll
import numpy as np
import random
from sklearn import cluster as sc

# train data
srcdata = [ [ 5.1, 3.5, 1.4, 0.2],
            [ 4.9, 3.0, 1.4, 0.2 ],[ 4.7, 3.2, 1.3, 0.2 ],
            [ 4.6, 3.1, 1.5, 0.2 ],[ 5.0, 3.6, 1.4, 0.2 ],
            [ 7.0, 3.2, 4.7, 1.4 ],[ 6.4, 3.2, 4.5, 1.5 ],
            [ 6.9, 3.1, 4.9, 1.5 ],[ 5.5, 2.3, 4.0, 1.3 ],
            [ 6.5, 2.8, 4.6, 1.5 ],[ 5.7, 2.8, 4.5, 1.3 ],
            [ 6.5, 3.0, 5.8, 2.2 ],[ 7.6, 3.0, 6.6, 2.1 ],
            [ 4.9, 2.5, 4.5, 1.7 ],[ 7.3, 2.9, 6.3, 1.8 ],
            [ 6.7, 2.5, 5.8, 1.8 ],[ 6.9, 3.1, 5.1, 2.3 ] ]

#print srcdata

srclen = len(srcdata)

len0 = len(srcdata[0])

# set k = 5
k = 5

# step 1

# index = range(k) -- use first k

# use random to get index.
index = random.sample( range( srclen ) , k)

#print index

c = np.arange(k*len0,dtype=float).reshape(k,len0)

for i in range(k):
    for j in range(len0):
        c[i][j] = srcdata[index[i]][j]

print c


# set stop threshold

delta = 0.001

dis = np.arange(k,dtype=float)

calindex = np.arange(srclen)

norm = 10.0

# step 4

while norm>delta:

    # step 2

    for i in range( srclen ):

        for m in range(k):
            dis[m] = 0.0

        for j in range(len0):
            for m in range(k):

                dis[m] += (srcdata[i][j] - c[m][j])*(srcdata[i][j] - c[m][j])

        calindex[i] = np.where(dis == np.min(dis) )[0][0]

    # step 3

    newc = np.zeros(k*len0).reshape(k,len0)

    for xx in range(srclen):
        for j in range(len0):
            newc[calindex[xx]][j] += srcdata[xx][j]
    for i in range(k):
        size = len(np.where(calindex == i)[0])
        #print size
        for j in range(len0):
            newc[i][j]/=size

    norm = ll.norm(c-newc)

    #print norm

    c = newc.copy()

print c

print calindex

## sklearn -> k-means++

cc = sc.k_means(srcdata,k)

print cc

Java

来源

import java.io.BufferedReader;  
import java.io.FileInputStream;  
import java.io.IOException;  
import java.io.InputStreamReader;  
import java.util.ArrayList;  
import java.util.List;  
import java.util.Random;  

public class Kmeans {
     

    /** 
     * @param args 
     * @throws IOException 
     */  

    public static List<ArrayList<ArrayList<Double>>>   
    initHelpCenterList(List<ArrayList<ArrayList<Double>>> helpCenterList,int k){  
        for(int i=0;i<k;i++){  
            helpCenterList.add(new ArrayList<ArrayList<Double>>());  
        }     
        return helpCenterList;  
    }  

    /** 
     * @param args 
     * @throws IOException 
     */  
    public static void main(String[] args) throws IOException{  

        List<ArrayList<Double>> centers = new ArrayList<ArrayList<Double>>();  
        List<ArrayList<Double>> newCenters = new ArrayList<ArrayList<Double>>();  
        List<ArrayList<ArrayList<Double>>> helpCenterList = new ArrayList<ArrayList<ArrayList<Double>>>();  

        //读入原始数据  
        BufferedReader br=new BufferedReader(new InputStreamReader(new FileInputStream("wine.txt")));  
        String data = null;  
        List<ArrayList<Double>> dataList = new ArrayList<ArrayList<Double>>();  
        while((data=br.readLine())!=null){  
            //System.out.println(data);  
            String []fields = data.split(",");  
            List<Double> tmpList = new ArrayList<Double>();  
            for(int i=0; i<fields.length;i++)  
                tmpList.add(Double.parseDouble(fields[i]));  
            dataList.add((ArrayList<Double>) tmpList);  
        }  
        br.close();  

        //随机确定K个初始聚类中心  
        Random rd = new Random();  
        int k=3;  
        int [] initIndex={
  59,71,48};  
        int [] helpIndex = {
  0,59,130};  
        int [] givenIndex = {
  0,1,2};  
        System.out.println("random centers' index");  
        for(int i=0;i<k;i++){  
            int index = rd.nextInt(initIndex[i]) + helpIndex[i];  
            //int index = givenIndex[i];  
            System.out.println("index "+index);  
            centers.add(dataList.get(index));  
            helpCenterList.add(new ArrayList<ArrayList<Double>>());  
        }     

        /* 
        //注释掉的这部分目的是,取测试数据集最后稳定的三个类簇的聚类中心作为初始聚类中心 
        centers = new ArrayList<ArrayList<Double>>(); 
        for(int i=0;i<59;i++) 
            helpCenterList.get(0).add(dataList.get(i)); 
        for(int i=59;i<130;i++) 
            helpCenterList.get(1).add(dataList.get(i)); 
        for(int i=130;i<178;i++) 
            helpCenterList.get(2).add(dataList.get(i)); 
        for(int i=0;i<k;i++){ 

            ArrayList<Double> tmp = new ArrayList<Double>(); 

            for(int j=0;j<dataList.get(0).size();j++){ 
                double sum=0; 
                for(int t=0;t<helpCenterList.get(i).size();t++) 
                    sum+=helpCenterList.get(i).get(t).get(j); 
                tmp.add(sum/helpCenterList.get(i).size()); 
            } 
            centers.add(tmp); 
        } 
        */  

        //输出k个初始中心  
        System.out.println("original centers:");  
        for(int i=0;i<k;i++)  
            System.out.println(centers.get(i));  

        while(true)  
        {
  //进行若干次迭代,直到聚类中心稳定  

            for(int i=0;i<dataList.size();i++){
  //标注每一条记录所属于的中心  
                double minDistance=99999999;  
                int centerIndex=-1;  
                for(int j=0;j<k;j++){
  //离0~k之间哪个中心最近  
                    double currentDistance=0;  
                    for(int t=1;t<centers.get(0).size();t++){
  //计算两点之间的欧式距离  
                        currentDistance +=  ((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t))) * ((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t)));   
                    }  
                    if(minDistance>currentDistance){  
                        minDistance=currentDistance;  
                        centerIndex=j;  
                    }  
                }  
                helpCenterList.get(centerIndex).add(dataList.get(i));  
            }  

        //  System.out.println(helpCenterList);  

            //计算新的k个聚类中心  
            for(int i=0;i<k;i++){  

                ArrayList<Double> tmp = new ArrayList<Double>();  

                for(int j=0;j<centers.get(0).size();j++){  
                    double sum=0;  
                    for(int t=0;t<helpCenterList.get(i).size();t++)  
                        sum+=helpCenterList.get(i).get(t).get(j);  
                    tmp.add(sum/helpCenterList.get(i).size());  
                }  

                newCenters.add(tmp);  

            }  
            System.out.println("\nnew clusters' centers:\n");  
            for(int i=0;i<k;i++)  
                System.out.println(newCenters.get(i));  
            //计算新旧中心之间的距离,当距离小于阈值时,聚类算法结束  
            double distance=0;  

            for(int i=0;i<k;i++){  
                for(int j=1;j<centers.get(0).size();j++){
  //计算两点之间的欧式距离  
                    distance += ((centers.get(i).get(j)-newCenters.get(i).get(j))/(centers.get(i).get(j)+newCenters.get(i).get(j))) * ((centers.get(i).get(j)-newCenters.get(i).get(j))/(centers.get(i).get(j)+newCenters.get(i).get(j)));   
                }  
                //System.out.println(i+" "+distance);  
            }  
            System.out.println("\ndistance: "+distance+"\n\n");  
            if(distance==0)//小于阈值时,结束循环  
                break;  
            else//否则,新的中心来代替旧的中心,进行下一轮迭代  
            {  
                centers = new ArrayList<ArrayList<Double>>(newCenters);  
                //System.out.println(newCenters);  
                newCenters = new ArrayList<ArrayList<Double>>();  
                helpCenterList = new ArrayList<ArrayList<ArrayList<Double>>>();  
                helpCenterList=initHelpCenterList(helpCenterList,k);  
            }  
        }  
        //输出最后聚类结果  
        for(int i=0;i<k;i++){  
            System.out.println("\n\nCluster: "+(i+1)+"   size: "+helpCenterList.get(i).size()+" :\n\n");  
            for(int j=0;j<helpCenterList.get(i).size();j++)  
            {  
                System.out.println(helpCenterList.get(i).get(j));  
            }  
        }  
    }  
}  

test

测试数据集来源:wine数据集from UCI

以下是某次的运行结果,可以看出聚类结果与参考结果比较吻合

random centers' index  
index 4  
index 102  
index 166  
original centers:  
[1.0, 13.744745762711865, 2.0106779661016954, 2.455593220338984, 17.037288135593222, 106.33898305084746, 2.8401694915254234, 2.982372881355932, 0.29, 1.8993220338983055, 5.528305084745763, 1.0620338983050848, 3.1577966101694916, 1115.7118644067796]  
[2.0, 12.278732394366198, 1.932676056338028, 2.244788732394365, 20.238028169014086, 94.54929577464789, 2.2588732394366198, 2.080845070422536, 0.363661971830986, 1.6302816901408452, 3.08661971830986, 1.0562816901408452, 2.785352112676055, 519.5070422535211]  
[3.0, 13.153749999999997, 3.3337500000000007, 2.4370833333333333, 21.416666666666668, 99.3125, 1.6787500000000002, 0.7814583333333331, 0.44749999999999995, 1.1535416666666667, 7.396249979166668, 0.6827083333333334, 1.6835416666666658, 629.8958333333334]  

new clusters' centers:  

[1.040983606557377, 13.695245901639344, 1.9844262295081967, 2.446475409836067, 17.199999999999996, 106.13934426229508, 2.8595901639344263, 3.0022950819672127, 0.28737704918032797, 1.900901639344262, 5.5009836065573765, 1.062131147540984, 3.1613934426229497, 1101.0]  
[2.0, 12.272352941176464, 1.9286764705882358, 2.2503676470588228, 20.209558823529413, 94.72058823529412, 2.229705882352941, 2.0449999999999995, 0.3635294117647058, 1.6292647058823533, 3.0122794117647054, 1.0586176470588233, 2.7795588235294093, 517.6102941176471]  
[2.9693877551020407, 13.146530612244899, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值