KMeans算法实现

闲来无事,自己琢磨了一下 KMeans 算法,在此记录。其原理网上资料很多,不再赘述。

# -*- coding: utf-8 -*-
"""
Created on Wed May 16 23:02:51 2018

@author: mz
"""
import math
import random
from sklearn import datasets
import numpy  as np
import copy

class Cluster(object):
    """Base class for clustering algorithms; provides Euclidean distance helpers."""

    def center(self, data, minPoints, k, n):
        """Cluster *data*. Abstract hook: this base version does nothing and returns None."""
        pass

    def distance(self, source: list, center: list) -> float:
        """
        Return the Euclidean distance between two points of any dimension.

        Coordinates are paired with ``zip``, so both points should have the
        same length; surplus coordinates of the longer one are ignored.
        """
        return math.sqrt(sum((s - c) ** 2 for s, c in zip(source, center)))

    def minDistance(self, point, centroids):
        """
        Find the centroid nearest to *point*.

        Returns:
            (min_dist, index): the smallest distance and the position of that
            centroid in *centroids*; ties go to the lowest index. For an
            empty *centroids* the result is ``(math.inf, -1)``.
        """
        min_dist = math.inf
        index = -1
        for i, cen in enumerate(centroids):
            dist = self.distance(cen, point)
            if dist < min_dist:
                min_dist, index = dist, i
        return min_dist, index


class KMeans(Cluster):
    """
    K-means clustering with distance-weighted (roulette-wheel) seeding.

    ``center`` arguments:
        data      : sequence of equal-length numeric points (e.g. a 2-D numpy
                    array with one row per point)
        minPoints : accepted for interface compatibility; currently unused
        k         : number of centroids
        n         : maximum number of assign/update iterations
        tol       : stop once the mean centroid shift drops below this value

    Seeding favours initial centroids that are far apart, which usually gives
    better clusterings than purely uniform seeding.
    """

    def randomCenter(self, data, k):
        """
        Choose *k* initial centroids from *data*.

        The first centroid is drawn uniformly; each later one is drawn with
        probability proportional to its distance from the nearest centroid
        already chosen (roulette-wheel selection).

        NOTE: classic k-means++ weights by the *squared* distance; this keeps
        the linear weighting of the original implementation.

        Returns:
            A list of *k* rows of *data* (``[]`` when ``k <= 0``).
        """
        if k <= 0:
            return []

        centers = [random.choice(data)]
        for _ in range(1, k):
            dists = [self.minDistance(point, centers)[0] for point in data]
            threshold = sum(dists) * random.random()

            chosen = None
            running = 0.0
            for point, dist in zip(data, dists):
                running += dist
                # Strict '>' guarantees a zero-weight point (one that already
                # coincides with a centroid) is never selected.
                if running > threshold:
                    chosen = point
                    break

            if chosen is None:
                # Only reached when float rounding pushes the threshold up to
                # the total, or when every distance is 0 (all points already
                # coincide with a centroid): take the farthest point.
                chosen = data[max(range(len(dists)), key=dists.__getitem__)]
            centers.append(chosen)

        return centers

    def center(self, data, minPoints, k, n, tol=0.0001):
        """
        Run k-means on *data*.

        Args:
            data: sequence of equal-length numeric points.
            minPoints: unused; kept for interface compatibility.
            k: number of clusters (must be positive).
            n: maximum number of assign/update iterations.
            tol: convergence threshold on the mean centroid shift.

        Returns:
            (centers, clusters): ``centers`` is a list of *k* float numpy
            arrays; ``clusters[i]`` holds the points of *data* assigned to
            centroid ``i`` in the final assignment step, and ``centers[i]`` is
            their mean (an empty cluster keeps its previous centroid). When
            ``n <= 0`` no iteration runs: the seeds and their assignment are
            returned.

        Raises:
            ValueError: if ``k <= 0``.
        """
        if k <= 0:
            raise ValueError("k must be a positive integer")

        # Float copies: never alias rows of *data*, and never truncate the
        # means when *data* has an integer dtype.
        centers = [np.array(c, dtype=float) for c in self.randomCenter(data, k)]

        clusters = None
        for _ in range(n):
            clusters = self._assign_points(data, centers)
            new_centers = self._recompute_centers(clusters, centers)
            shift = sum(self.distance(old, new)
                        for old, new in zip(centers, new_centers))
            centers = new_centers
            if shift / k < tol:
                break

        if clusters is None:
            # n <= 0: report the assignment to the initial seeds.
            clusters = self._assign_points(data, centers)

        return centers, clusters

    def _assign_points(self, data, centers):
        """Group the points of *data* by the index of their nearest centroid."""
        groups = [[] for _ in centers]
        for point in data:
            _, index = self.minDistance(point, centers)
            groups[index].append(point)
        return groups

    @staticmethod
    def _recompute_centers(groups, centers):
        """Return each group's mean; an empty group keeps its old centroid."""
        return [
            np.mean(np.asarray(group, dtype=float), axis=0) if group else old.copy()
            for group, old in zip(groups, centers)
        ]
            
            
            
            


if __name__ == "__main__":
    # Demo: cluster the iris dataset shipped with scikit-learn and print
    # the resulting centroids and cluster memberships.
    iris = datasets.load_iris()
    kmean = KMeans()

    k = 3
    # Arguments: data, minPoints (unused by KMeans.center), k, max iterations.
    centers, clusters = kmean.center(iris.data, 10, k, 500)

    print("centers = ", centers)
    # Iterate over every cluster rather than hard-coding indices 0..2, so
    # changing k cannot cause an IndexError or silently skip clusters.
    for i, members in enumerate(clusters):
        print(f"cluster[{i}] := ", members, "length := ", len(members))

     
     
     
     
    

运行结果:

centers = 
 [array([ 6.85384615,  3.07692308,  5.71538462,  2.05384615]),
 array([ 5.88360656,  2.74098361,  4.38852459,  1.43442623]), 
 array([ 5.006,  3.418,  1.464,  0.244])]
 
cluster[0] :=  [array([ 7. ,  3.2,  4.7,  1.4]), array([ 6.9,  3.1,  4.9,  1.5]), array([ 6.7,  3. ,  5. ,  1.7]), array([ 6.3,  3.3,  6. ,  2.5]), array([ 7.1,  3. ,  5.9,  2.1]), array([ 6.3,  2.9,  5.6,  1.8]), array([ 6.5,  3. ,  5.8,  2.2]), array([ 7.6,  3. ,  6.6,  2.1]), array([ 7.3,  2.9,  6.3,  1.8]), array([ 6.7,  2.5,  5.8,  1.8]), array([ 7.2,  3.6,  6.1,  2.5]), array([ 6.5,  3.2,  5.1,  2. ]), array([ 6.4,  2.7,  5.3,  1.9]), array([ 6.8,  3. ,  5.5,  2.1]), array([ 6.4,  3.2,  5.3,  2.3]), array([ 6.5,  3. ,  5.5,  1.8]), array([ 7.7,  3.8,  6.7,  2.2]), array([ 7.7,  2.6,  6.9,  2.3]), array([ 6.9,  3.2,  5.7,  2.3]), array([ 7.7,  2.8,  6.7,  2. ]), array([ 6.7,  3.3,  5.7,  2.1]), array([ 7.2,  3.2,  6. ,  1.8]), array([ 6.4,  2.8,  5.6,  2.1]), array([ 7.2,  3. ,  5.8,  1.6]), array([ 7.4,  2.8,  6.1,  1.9]), array([ 7.9,  3.8,  6.4,  2. ]), array([ 6.4,  2.8,  5.6,  2.2]), array([ 6.1,  2.6,  5.6,  1.4]), array([ 7.7,  3. ,  6.1,  2.3]), array([ 6.3,  3.4,  5.6,  2.4]), array([ 6.4,  3.1,  5.5,  1.8]), array([ 6.9,  3.1,  5.4,  2.1]), array([ 6.7,  3.1,  5.6,  2.4]), array([ 6.9,  3.1,  5.1,  2.3]), array([ 6.8,  3.2,  5.9,  2.3]), array([ 6.7,  3.3,  5.7,  2.5]), array([ 6.7,  3. ,  5.2,  2.3]), array([ 6.5,  3. ,  5.2,  2. ]), array([ 6.2,  3.4,  5.4,  2.3])] length :=  39

cluster[1] :=  [array([ 6.4,  3.2,  4.5,  1.5]), array([ 5.5,  2.3,  4. ,  1.3]), array([ 6.5,  2.8,  4.6,  1.5]), array([ 5.7,  2.8,  4.5,  1.3]), array([ 6.3,  3.3,  4.7,  1.6]), array([ 4.9,  2.4,  3.3,  1. ]), array([ 6.6,  2.9,  4.6,  1.3]), array([ 5.2,  2.7,  3.9,  1.4]), array([ 5. ,  2. ,  3.5,  1. ]), array([ 5.9,  3. ,  4.2,  1.5]), array([ 6. ,  2.2,  4. ,  1. ]), array([ 6.1,  2.9,  4.7,  1.4]), array([ 5.6,  2.9,  3.6,  1.3]), array([ 6.7,  3.1,  4.4,  1.4]), array([ 5.6,  3. ,  4.5,  1.5]), array([ 5.8,  2.7,  4.1,  1. ]), array([ 6.2,  2.2,  4.5,  1.5]), array([ 5.6,  2.5,  3.9,  1.1]), array([ 5.9,  3.2,  4.8,  1.8]), array([ 6.1,  2.8,  4. ,  1.3]), array([ 6.3,  2.5,  4.9,  1.5]), array([ 6.1,  2.8,  4.7,  1.2]), array([ 6.4,  2.9,  4.3,  1.3]), array([ 6.6,  3. ,  4.4,  1.4]), array([ 6.8,  2.8,  4.8,  1.4]), array([ 6. ,  2.9,  4.5,  1.5]), array([ 5.7,  2.6,  3.5,  1. ]), array([ 5.5,  2.4,  3.8,  1.1]), array([ 5.5,  2.4,  3.7,  1. ]), array([ 5.8,  2.7,  3.9,  1.2]), array([ 6. ,  2.7,  5.1,  1.6]), array([ 5.4,  3. ,  4.5,  1.5]), array([ 6. ,  3.4,  4.5,  1.6]), array([ 6.7,  3.1,  4.7,  1.5]), array([ 6.3,  2.3,  4.4,  1.3]), array([ 5.6,  3. ,  4.1,  1.3]), array([ 5.5,  2.5,  4. ,  1.3]), array([ 5.5,  2.6,  4.4,  1.2]), array([ 6.1,  3. ,  4.6,  1.4]), array([ 5.8,  2.6,  4. ,  1.2]), array([ 5. ,  2.3,  3.3,  1. ]), array([ 5.6,  2.7,  4.2,  1.3]), array([ 5.7,  3. ,  4.2,  1.2]), array([ 5.7,  2.9,  4.2,  1.3]), array([ 6.2,  2.9,  4.3,  1.3]), array([ 5.1,  2.5,  3. ,  1.1]), array([ 5.7,  2.8,  4.1,  1.3]), array([ 5.8,  2.7,  5.1,  1.9]), array([ 4.9,  2.5,  4.5,  1.7]), array([ 5.7,  2.5,  5. ,  2. ]), array([ 5.8,  2.8,  5.1,  2.4]), array([ 6. ,  2.2,  5. ,  1.5]), array([ 5.6,  2.8,  4.9,  2. ]), array([ 6.3,  2.7,  4.9,  1.8]), array([ 6.2,  2.8,  4.8,  1.8]), array([ 6.1,  3. ,  4.9,  1.8]), array([ 6.3,  2.8,  5.1,  1.5]), array([ 6. ,  3. ,  4.8,  1.8]), array([ 5.8,  2.7,  5.1,  1.9]), array([ 6.3,  2.5,  5. 
,  1.9]), array([ 5.9,  3. ,  5.1,  1.8])] length :=  61

cluster[2] :=  [array([ 5.1,  3.5,  1.4,  0.2]), array([ 4.9,  3. ,  1.4,  0.2]), array([ 4.7,  3.2,  1.3,  0.2]), array([ 4.6,  3.1,  1.5,  0.2]), array([ 5. ,  3.6,  1.4,  0.2]), array([ 5.4,  3.9,  1.7,  0.4]), array([ 4.6,  3.4,  1.4,  0.3]), array([ 5. ,  3.4,  1.5,  0.2]), array([ 4.4,  2.9,  1.4,  0.2]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 5.4,  3.7,  1.5,  0.2]), array([ 4.8,  3.4,  1.6,  0.2]), array([ 4.8,  3. ,  1.4,  0.1]), array([ 4.3,  3. ,  1.1,  0.1]), array([ 5.8,  4. ,  1.2,  0.2]), array([ 5.7,  4.4,  1.5,  0.4]), array([ 5.4,  3.9,  1.3,  0.4]), array([ 5.1,  3.5,  1.4,  0.3]), array([ 5.7,  3.8,  1.7,  0.3]), array([ 5.1,  3.8,  1.5,  0.3]), array([ 5.4,  3.4,  1.7,  0.2]), array([ 5.1,  3.7,  1.5,  0.4]), array([ 4.6,  3.6,  1. ,  0.2]), array([ 5.1,  3.3,  1.7,  0.5]), array([ 4.8,  3.4,  1.9,  0.2]), array([ 5. ,  3. ,  1.6,  0.2]), array([ 5. ,  3.4,  1.6,  0.4]), array([ 5.2,  3.5,  1.5,  0.2]), array([ 5.2,  3.4,  1.4,  0.2]), array([ 4.7,  3.2,  1.6,  0.2]), array([ 4.8,  3.1,  1.6,  0.2]), array([ 5.4,  3.4,  1.5,  0.4]), array([ 5.2,  4.1,  1.5,  0.1]), array([ 5.5,  4.2,  1.4,  0.2]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 5. ,  3.2,  1.2,  0.2]), array([ 5.5,  3.5,  1.3,  0.2]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 4.4,  3. ,  1.3,  0.2]), array([ 5.1,  3.4,  1.5,  0.2]), array([ 5. ,  3.5,  1.3,  0.3]), array([ 4.5,  2.3,  1.3,  0.3]), array([ 4.4,  3.2,  1.3,  0.2]), array([ 5. ,  3.5,  1.6,  0.6]), array([ 5.1,  3.8,  1.9,  0.4]), array([ 4.8,  3. ,  1.4,  0.3]), array([ 5.1,  3.8,  1.6,  0.2]), array([ 4.6,  3.2,  1.4,  0.2]), array([ 5.3,  3.7,  1.5,  0.2]), array([ 5. ,  3.3,  1.4,  0.2])] length :=  50

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值