通过Python写了个简单的K-Means分类
具体方法其实很简单:
- 生成几类随机数据点points
- 随机生成K个中心点centers
- 对每个点point求取距离最近的中心点center 即分类
- 对每个分类集中的数据点求取平均点作为新的中心点坐标
- 如果所有新的中心点 和 旧中心点 的距离都小于一定阈值 说明分类完成;否则迭代
import matplotlib.pyplot as plt
import numpy as np
import random
from icecream import ic
from collections import defaultdict
from matplotlib.colors import BASE_COLORS
def random_centers(k, points):
for i in range(k):
#在原本的可能坐标中随机生成k个中心点
yield random.choice(points[:, 0]), random.choice(points[:, 1])
def mean(points):
#all_x,all_y都是列表
all_x, all_y = [x for x, y in points], [y for x, y in points]
return np.mean(all_x), np.mean(all_y)
def distance(p1, p2):
#求取两点之间的距离
x1, y1 = p1
x2, y2 = p2
return np.sqrt((x1 - x2) ** 2 + (y1 - y2)**2)
def draw_points(centers,centers_neighbor,colors):
#遍历每个中心点
for i, c in enumerate(centers):
#获取该中心点 所涵盖的point集合
_points = centers_neighbor[c]
all_x, all_y = [x for x, y in _points], [y for x, y in _points]
#将对应点绘制颜色
plt.scatter(all_x, all_y, c=colors[i])
plt.show()
def kmeans(k, points, centers=None):
#获取一个代表颜色信息值的列表
colors = list(BASE_COLORS.values())
#如果没有生成centers,则随机生成一个
if not centers:
centers = list(random_centers(k=k, points=points))
#方便调试
ic(centers)
for i, c in enumerate(centers):#enumerate() 将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
plt.scatter([c[0]], [c[1]], s=90, marker='*', c=colors[i])#绘制散点图
plt.scatter(*zip(*points), c='black')
#defaultdict的作用是在于,当字典里的key不存在但被查找时,返回的不是keyError而是一个默认值 set对应set( ),即没有key时返回一个空集合
centers_neighbor = defaultdict(set)
for p in points:
#min函数返回的是一个 中心点坐标
closet_c = min(centers, key=lambda c: distance(p, c))
#将points加入最近的中心点集合
centers_neighbor[closet_c].add(tuple(p))
#ic(centers_neighbor)
draw_points(centers,centers_neighbor,colors)
new_centers = []
for c in centers_neighbor:
#对每个中心点所包含的所有点求其平均值,作为新的中心点
new_c = mean(centers_neighbor[c])
new_centers.append(new_c)
threshold = 0.1
distances_old_and_new = [distance(c_old, c_new) for c_old, c_new in zip(centers, new_centers)]
#ic(distances_old_and_new)
if all(c < threshold for c in distances_old_and_new):
return centers_neighbor
else:
kmeans(k, points, new_centers)
if __name__ == '__main__':
#随机生成四组数据
points0 = np.random.normal(loc=1, size=(100,2))
points1 = np.random.normal(loc=2, size=(100, 2))
points2 = np.random.normal(loc=4, size=(100, 2))
points3 = np.random.normal(loc=5, size=(100, 2))
points = np.concatenate([points0, points1, points2, points3])
kmeans(3,points=points,centers=None)
效果图:
第一次迭代:
第二次迭代:
第三次迭代:
第四次迭代:
第五次迭代:
分类完成!