Machine Learning: K-Means++ Algorithm and Implementation

K-Means++ Algorithm

Key points
Compared with K-Means, only the way the initial cluster centers are chosen changes. Because K-Means generates its K initial cluster centers at random, how well the algorithm converges depends strongly on that random initialization; K-Means++ improves exactly this initialization step.

  • step 1: randomly select one point from the input sample set as the first cluster center
  • step 2: for every sample point, compute the distance D(x) to its nearest already-chosen cluster center
  • step 3: select a new sample point as the next cluster center; the selection rule is that the larger D(x) is, the higher the probability of that point being chosen
  • step 4: repeat steps 2 and 3 until K cluster centers have been selected
  • step 5: run K-Means with these K cluster centers as the initial centers (see the earlier post on the K-Means algorithm: https://blog.csdn.net/weixin_38250282/article/details/83589259)
  • in step 3, D(x) has to be turned into a probability; this article uses a simple proportional rule, i.e. P = k * D(x) (the canonical K-Means++ weights by D(x)^2, but the linear form is kept here); a short sketch of this selection rule follows this list
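
To make the selection rule of step 3 concrete, here is a minimal, self-contained sketch (not part of the original code) of drawing the next center index with probability proportional to D(x); the toy array D and its values are made up purely for illustration.

    import numpy as np
    
    rng = np.random.default_rng(0)
    
    # D(x): distance from each sample to its nearest already-chosen center (toy values)
    D = np.array([1.0, 3.0, 6.0])
    
    # step 3: probability proportional to D(x), i.e. P(x) = D(x) / sum(D)
    probs = D / D.sum()                              # here: [0.1, 0.3, 0.6]
    next_center_index = rng.choice(len(D), p=probs)  # farther points are more likely
    print(next_center_index)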

    K-Means++ Cluster Center Initialization
    Generating the K initial cluster centers with probability-weighted selection:

    # K-Means++ initialization: pick k centers, each with probability proportional to D(x)
    def randProb(dataSet, k, m, disMeas=disEclud):
        n = np.shape(dataSet)[1]
        centCoids = np.mat(np.zeros((k, n)))
        dis = np.ones(m)    # D(x) for every sample; all ones, so the first center is uniform random
        
        for i in range(k):
            # roulette-wheel selection: sample j is picked with probability dis[j] / sum(dis)
            sumDis = float(sum(dis))
            rand = sumDis * random()
            j = 0
            while rand > 0:
                rand = rand - dis[j]
                j = j + 1
            centCoids[i,:] = np.mat(dataSet[j-1,:])
            
            # update D(x): distance from every sample to its nearest already-chosen center
            for p in range(m):
                minDis = 1e100
                for l in range(i+1):
                    disPL = disMeas(dataSet[p,:], centCoids.A[l,:])
                    if disPL < minDis:
                        minDis = disPL
                dis[p] = minDis
        
        return centCoids
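
    A quick usage sketch (assuming the imports and the disEclud distance function from the full listing below are in scope): the initializer can be called on its own to inspect the chosen starting centers; the variable names here are illustrative only.

    from sklearn.datasets import make_blobs
    
    X, _ = make_blobs(random_state=1)          # 100 two-dimensional samples in 3 blobs
    initCents = randProb(X, 3, X.shape[0])     # 3 x 2 matrix of K-Means++ initial centers
    print(initCents)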
    

    Example
    Dataset: the make_blobs dataset from sklearn; the clustering result is plotted by the draw() function in the implementation below.

    Python Implementation

    # -*- coding: utf-8 -*-
    # dataSet: m test samples, each of dimension n
    # disMeas: distance measure; Euclidean distance is used here
    # createInitCent: initialization of the k centers; the K-Means++ scheme (randProb) is used
    # clusterAssment: one row per sample; column 0 = index of the assigned center, column 1 = squared distance to that center
    # centCoids: the k cluster centers, size = k*n
    
    import numpy as np
    import math
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt
    from random import random
    
    # Euclidean distance between two vectors
    def disEclud(vetA, vetB):
        return math.sqrt(sum(pow(vetA - vetB, 2)))
    
    # K-Means++ initialization: pick k centers, each with probability proportional to D(x)
    def randProb(dataSet, k, m, disMeas=disEclud):
        n = np.shape(dataSet)[1]
        centCoids = np.mat(np.zeros((k, n)))
        dis = np.ones(m)    # D(x) for every sample; all ones, so the first center is uniform random
        
        for i in range(k):
            # roulette-wheel selection: sample j is picked with probability dis[j] / sum(dis)
            sumDis = float(sum(dis))
            rand = sumDis * random()
            j = 0
            while rand > 0:
                rand = rand - dis[j]
                j = j + 1
            centCoids[i,:] = np.mat(dataSet[j-1,:])
            
            # update D(x): distance from every sample to its nearest already-chosen center
            for p in range(m):
                minDis = 1e100
                for l in range(i+1):
                    disPL = disMeas(dataSet[p,:], centCoids.A[l,:])
                    if disPL < minDis:
                        minDis = disPL
                dis[p] = minDis
        
        return centCoids
    
    # standard K-Means started from the K-Means++ centers; iterates until assignments stop changing
    def kMeans(dataSet, k, disMeas=disEclud, createInitCent=randProb):
        m = np.shape(dataSet)[0]
        clusterAssment = np.mat(np.zeros((m, 2)))
        centCoids = createInitCent(dataSet, k, m)
        clusterChanged = True
        while clusterChanged:
            clusterChanged = False
            # assignment step: attach every sample to its nearest center
            for i in range(m):
                minDis = float('inf')
                minIndex = -1
                for j in range(k):
                    disIJ = disMeas(dataSet[i,:], centCoids.A[j,:])
                    if disIJ < minDis:
                        minDis = disIJ
                        minIndex = j
                if clusterAssment[i, 0] != minIndex:
                    clusterChanged = True
                clusterAssment[i,:] = minIndex, minDis**2
            
            # update step: move every center to the mean of the samples assigned to it
            for cent in range(k):
                pstInClust = dataSet[np.nonzero(clusterAssment.A[:,0] == cent)[0]]
                if len(pstInClust) > 0:
                    centCoids[cent,:] = np.mean(pstInClust, axis=0)
            
        return centCoids, clusterAssment
    
    # plot the samples colored by cluster and mark the final centers in red
    def draw(center, dataSet, clusterAssment):
        plt.figure()
        length = np.shape(center)[0]
        color = ['b', 'g', 'y']    # one color per cluster (k = 3 here)
        for i in range(length):
            plt.scatter(center[i,0], center[i,1], c='r')
            pstInClustX = dataSet[np.nonzero(clusterAssment.A[:,0] == i)[0], 0]
            pstInClustY = dataSet[np.nonzero(clusterAssment.A[:,0] == i)[0], 1]
            plt.scatter(pstInClustX, pstInClustY, s=25, c=color[i], alpha=0.4)
        plt.show()
        
    def main():
        # generate a 2-D toy dataset (3 blobs), cluster it with k = 3 and plot the result
        X, y = make_blobs(random_state=1)
        centCoids, clusterAssment = kMeans(X, 3)
        draw(centCoids, X, clusterAssment)
    
    if __name__ == '__main__':
        main()
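
    As a quick cross-check, the same data can be clustered with scikit-learn's built-in KMeans, which uses k-means++ initialization by default; its centers should land close to those returned by kMeans(X, 3) above. This snippet is a sketch added for comparison and is not part of the original listing; the variable names km and X are illustrative only.

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    
    X, _ = make_blobs(random_state=1)
    
    # scikit-learn's KMeans defaults to init='k-means++'
    km = KMeans(n_clusters=3, init='k-means++', n_init=10, random_state=1)
    km.fit(X)
    
    print(km.cluster_centers_)    # compare with the centers found by kMeans(X, 3)
    print(km.inertia_)            # total squared distance of samples to their nearest center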
    