基于sklearn的k-means++

import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

to_array_2 = lambda data : np.array(np.squeeze(data))[0]

def kmeans_building_from_sklearn(data,types_num,types_init = "k-means++"):
    '''
    K-means 算法,不限制维度
    :param data: numpy数据结构的一维或者多维的数组
    :param types_num: K点的选择数量
    :return: (分类好的新数据,分类的中心点,分类的切片标签)

    :Example:
    data = np.mat([[0.0, 7.8031236401009165], [0.1, 4.426411521373964], [0.2, 8.947706236035394], [0.3, 8.200555461071133], [0.4, 0.5142706050449086], [0.5, 3.1008492860590753], [0.6, 7.1785370142703], [0.7, 3.872126889453009], [0.8, 1.025577102380758], [0.9, 4.833507884197839], [1.0, 0.9345186488455648], [1.1, 4.812032669803522], [1.2, 3.4191665496674375], [1.3, 0.022961590431520573], [1.4, 0.8785497842851486], [1.5, 3.2381682303766004], [1.6, 6.663617230163021], [1.7, 0.9093960835056158], [1.8, 3.680521995791508], [1.9, 4.900535080573809], [2.0, 0.5937324344804573], [2.1, 7.783916463844781], [2.2, 7.11179421378249], [2.3, 2.4446164389372083], [2.4, 0.31413070861756043], [2.5, 8.793728586574554], [2.6, 5.826793655403399], [2.7, 5.694595209349293], [2.8, 1.810999124588204], [2.9, 2.5891746519869896], [3.0, 7.989527941362203], [3.1, 3.8888255411877237], [3.2, 6.980965568743198], [3.3, 6.183466511452052], [3.4, 8.994788345820844], [3.5, 7.713561865457374], [3.6, 5.373007398355321], [3.7, 5.857041801728861], [3.8, 1.2873991273003293], [3.9, 7.90280175406894], [4.0, 2.114965228889154], [4.1, 1.9122653164064485], [4.2, 1.430294828685612], [4.3, 6.447866312224862], [4.4, 1.6034730966952893], [4.5, 5.020523561692603], [4.6, 8.434230931666896], [4.7, 8.926232142246732], [4.8, 7.21575735052221], [4.9, 0.8242497089572909], [5.0, 3.2773727950676923], [5.1, 4.789459791385374], [5.2, 5.7771008824695205], [5.3, 5.475006081618368], [5.4, 0.1095089272289862], [5.5, 3.708028945757401], [5.6, 5.868541457709577], [5.7, 3.82129152340557], [5.8, 4.672995214563882], [5.9, 7.139883221140032], [6.0, 1.9266318905278135], [6.1, 8.37752018393436], [6.2, 0.38698741212173093], [6.3, 6.809020238683873], [6.4, 6.7650503938175754], [6.5, 3.622205528236381], [6.6, 6.2275067570806755], [6.7, 9.272360909999943], [6.8, 6.937813473353321], [6.9, 6.719818328095723], [7.0, 4.1761213627699085], [7.1, 2.2186584343523554], [7.2, 4.325155968720139], [7.3, 6.346030576482492], [7.4, 6.681255604404174], [7.5, 0.9022441439624651], [7.6, 3.5585144004608313], [7.7, 3.0389210408048797], [7.8, 1.2906111157619915], [7.9, 5.5836998346201785], [8.0, 9.35888445552132], [8.1, 7.895786013311998], [8.2, 8.908751780051233], [8.3, 0.5462991244242466], [8.4, 5.430329386519083], [8.5, 5.632137779391064], [8.6, 8.919006991403911], [8.7, 1.881419639976668], [8.8, 0.8227664012604219], [8.9, 6.551076298674367], [9.0, 0.8128558657766949], [9.1, 3.8227215510033807], [9.2, 5.628006555285536], [9.3, 7.4386834204408645], [9.4, 3.7984910271263073], [9.5, 4.012712837404079], [9.6, 9.755417444517088], [9.7, 2.8443889672100973], [9.8, 4.143513529213513], [9.9, 4.093044136223663]])
    result = kmeans_building_from_sklearn(data, 10) # 本例要分3类,所以传入一个3
    x = [to_array_2(data[i][:,0:1]) for i in result[-1]]
    y = [to_array_2(data[i][:,-1:]) for i in result[-1]]
    for i,j,Kp in zip(x,y,result[1]):
        plt.scatter(i,j)
        plt.scatter(Kp[0],Kp[-1],c='black', marker='x', alpha=1.0)
    plt.show()
    plt.close()
    '''
    kmeans_model = KMeans(n_clusters=types_num, init=types_init).fit(data)
    newdata = np.c_[kmeans_model.labels_,np.array(data)]
    slice_ = [np.squeeze((newdata==i)[:,0:1]) for i in range(types_num)]
    return [newdata[i] for i in slice_],kmeans_model.cluster_centers_,slice_

if __name__ == '__main__':
    data = np.mat([[0.0, 7.8031236401009165], [0.1, 4.426411521373964], [0.2, 8.947706236035394], [0.3, 8.200555461071133], [0.4, 0.5142706050449086], [0.5, 3.1008492860590753], [0.6, 7.1785370142703], [0.7, 3.872126889453009], [0.8, 1.025577102380758], [0.9, 4.833507884197839], [1.0, 0.9345186488455648], [1.1, 4.812032669803522], [1.2, 3.4191665496674375], [1.3, 0.022961590431520573], [1.4, 0.8785497842851486], [1.5, 3.2381682303766004], [1.6, 6.663617230163021], [1.7, 0.9093960835056158], [1.8, 3.680521995791508], [1.9, 4.900535080573809], [2.0, 0.5937324344804573], [2.1, 7.783916463844781], [2.2, 7.11179421378249], [2.3, 2.4446164389372083], [2.4, 0.31413070861756043], [2.5, 8.793728586574554], [2.6, 5.826793655403399], [2.7, 5.694595209349293], [2.8, 1.810999124588204], [2.9, 2.5891746519869896], [3.0, 7.989527941362203], [3.1, 3.8888255411877237], [3.2, 6.980965568743198], [3.3, 6.183466511452052], [3.4, 8.994788345820844], [3.5, 7.713561865457374], [3.6, 5.373007398355321], [3.7, 5.857041801728861], [3.8, 1.2873991273003293], [3.9, 7.90280175406894], [4.0, 2.114965228889154], [4.1, 1.9122653164064485], [4.2, 1.430294828685612], [4.3, 6.447866312224862], [4.4, 1.6034730966952893], [4.5, 5.020523561692603], [4.6, 8.434230931666896], [4.7, 8.926232142246732], [4.8, 7.21575735052221], [4.9, 0.8242497089572909], [5.0, 3.2773727950676923], [5.1, 4.789459791385374], [5.2, 5.7771008824695205], [5.3, 5.475006081618368], [5.4, 0.1095089272289862], [5.5, 3.708028945757401], [5.6, 5.868541457709577], [5.7, 3.82129152340557], [5.8, 4.672995214563882], [5.9, 7.139883221140032], [6.0, 1.9266318905278135], [6.1, 8.37752018393436], [6.2, 0.38698741212173093], [6.3, 6.809020238683873], [6.4, 6.7650503938175754], [6.5, 3.622205528236381], [6.6, 6.2275067570806755], [6.7, 9.272360909999943], [6.8, 6.937813473353321], [6.9, 6.719818328095723], [7.0, 4.1761213627699085], [7.1, 2.2186584343523554], [7.2, 4.325155968720139], [7.3, 6.346030576482492], [7.4, 6.681255604404174], [7.5, 0.9022441439624651], [7.6, 3.5585144004608313], [7.7, 3.0389210408048797], [7.8, 1.2906111157619915], [7.9, 5.5836998346201785], [8.0, 9.35888445552132], [8.1, 7.895786013311998], [8.2, 8.908751780051233], [8.3, 0.5462991244242466], [8.4, 5.430329386519083], [8.5, 5.632137779391064], [8.6, 8.919006991403911], [8.7, 1.881419639976668], [8.8, 0.8227664012604219], [8.9, 6.551076298674367], [9.0, 0.8128558657766949], [9.1, 3.8227215510033807], [9.2, 5.628006555285536], [9.3, 7.4386834204408645], [9.4, 3.7984910271263073], [9.5, 4.012712837404079], [9.6, 9.755417444517088], [9.7, 2.8443889672100973], [9.8, 4.143513529213513], [9.9, 4.093044136223663]])
    result = kmeans_building_from_sklearn(data, 10)  # 本例要分3类,所以传入一个3
    x = [to_array_2(data[i][:, 0:1]) for i in result[2]]
    y = [to_array_2(data[i][:, -1:]) for i in result[2]]
    for i, j, Kp in zip(x, y, result[1]):
        plt.scatter(i, j)
        plt.scatter(Kp[0], Kp[-1], c='black', marker='x', alpha=1.0)
    plt.show()
    plt.close()

与标准的平方误差和版本的k-means最大的优化点在于初始点的选择,选择了K点之间距离更远的K点作为初始值,从而保证更好的分类有效性,以及对簇之间的距离太近具有更好的宽容度。

  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值