python实现K-Means算法
k-means算法的实现原理就不再赘述,给大家说一下程序的大致思路。
程序定义了两个方法:一个是计算欧氏距离的 o_dis 方法(欧氏距离就是两点之间线段的长度,二维情况下用勾股定理求斜边长度即可);另一个是冗余较多的 junzhi 方法,它实现了算法的主要步骤:把每个数据点分配给距离最近的聚类中心,再对每个簇求均值,得到更新后的聚类中心。方法最后用 if 条件判断递归的出口(新的聚类中心与上一次相同时停止):
if origin_center_data_new != origin_center_data: junzhi(origin_data, origin_center_data_new)
代码块的最后是程序的入口,提供了 8 个二维数据点 origin_data,并使用给定的初始聚类中心 origin_center_data;这里还没有考虑随机生成初始聚类中心。
origin_data = [[2, 10], [2, 5], [8, 4], [5, 8], [7, 5], [6, 4], [1, 2], [4, 9]] origin_center_data = [[2, 10], [5, 8], [1, 2]] junzhi(data=origin_data, center=origin_center_data)
程序并不复杂,但没有写注释,也许有一天我自己都会感慨:这是我写的吗?所以需要耐心看完。嘿嘿
from math import sqrt
from numpy import *
# Compute the Euclidean distance between two points.
def o_dis(vector1, vector2):
    """Return the Euclidean distance between two points.

    Works for points of any (equal) dimension, not only 2-D: the result is
    the square root of the sum of squared coordinate differences.

    Args:
        vector1: first point, a sequence of numbers such as ``[x, y]``.
        vector2: second point, with the same number of coordinates.

    Returns:
        The distance as a plain Python ``float``.

    Raises:
        ValueError: if the two points have different dimensions.
    """
    # Local import: the module-level ``sqrt`` may be shadowed by the
    # ``from numpy import *`` at the top of the file.
    import math

    if len(vector1) != len(vector2):
        raise ValueError(
            "points must have the same dimension, got %d and %d"
            % (len(vector1), len(vector2)))
    # fsum gives an accurately rounded sum of the squared differences.
    return math.sqrt(math.fsum((a - b) ** 2 for a, b in zip(vector1, vector2)))
def junzhi(data, center, *, max_iter=100, verbose=True):
    """Run K-means clustering from the given initial centers.

    Each pass assigns every point to its nearest center (Euclidean distance
    via ``o_dis``; ties go to the center listed first). It then moves every
    center to the arithmetic mean of the points assigned to it. Passes
    repeat until the centers stop changing or ``max_iter`` passes are done.

    Args:
        data: sequence of points, e.g. ``[[2, 10], [2, 5], ...]``; every
            point must have the same dimension as the centers.
        center: sequence of initial centers; its length is the number of
            clusters k (any k >= 1, not only 3).
        max_iter: upper bound on the number of assign/update passes, so the
            loop always terminates.
        verbose: if true, print the per-point distances, the cluster
            memberships and the new centers of every pass.

    Returns:
        The final (converged) centers as a list of lists, in the same order
        as ``center``. A cluster that receives no points keeps its previous
        center. Empty ``data`` or ``center`` returns the centers unchanged.
    """
    rule = "-------------------------------------------------------------------------"
    points = [list(p) for p in data]
    centers = [list(c) for c in center]
    if not points or not centers:
        return centers

    for _ in range(max_iter):
        groups = [[] for _ in centers]

        # Assignment step: put each point into the group of its nearest center.
        for point in points:
            distances = []
            for c in centers:
                d = o_dis(point, c)
                if verbose:
                    print("原始数据为" + str(point), end='')
                    print("中心点为" + str(c), end='')
                    print("距离为:" + str(d))
                distances.append(d)
            # Explicit scan rather than builtin min(): keeps the first index
            # on ties, and ``min`` may be shadowed by ``from numpy import *``.
            nearest = 0
            for k in range(1, len(distances)):
                if distances[k] < distances[nearest]:
                    nearest = k
            groups[nearest].append(point)
            if verbose:
                print(distances)
                print(centers)
                print("最小距离为" + str(distances[nearest]))
                print("最小距离对应的中心点为" + str(centers[nearest]))
                for g in groups:
                    print(g)
                print(rule)

        # Update step: move each center to the mean (centroid) of its group.
        # The exact mean is used; truncating it with int() would leave
        # centers that are not the cluster centroids.
        new_centers = []
        for old, members in zip(centers, groups):
            if not members:
                # Empty cluster: keep the old center instead of dividing by 0.
                new_centers.append(old)
                continue
            totals = [0] * len(old)
            for p in members:
                for axis in range(len(old)):
                    totals[axis] += p[axis]
            new_centers.append([t / len(members) for t in totals])

        if verbose:
            for c in new_centers:
                print(c)
            print(rule)
            print(
                "下一次的中心点" + str(new_centers) + "==================================================================")

        converged = new_centers == centers
        centers = new_centers
        if converged:
            break
    return centers
# Example input: eight 2-D points and three hand-picked initial centers
# (random initialisation of the centers is not implemented).
origin_data = [[2, 10], [2, 5], [8, 4], [5, 8], [7, 5], [6, 4], [1, 2], [4, 9]]
origin_center_data = [[2, 10], [5, 8], [1, 2]]

if __name__ == "__main__":
    # Run the clustering only when executed as a script (not on import),
    # and report the centers that junzhi returns.
    final_centers = junzhi(data=origin_data, center=origin_center_data)
    print("最终的中心点" + str(final_centers))