k-means简单实现
k-means算法是典型的无监督式聚类算法,目的是将数据自动分为多个组别,如下图数据的分布以及k均值聚类后得到的结果
算法内容
-
选择k个点,作为初始的聚类中心
k的选择大多情况下为人为给定,无法给定时可尝试多个k值,分析后得到恰当的k值
-
遍历每个数据点,将其与其距离最近的中心点归为一类
-
计算每一类别的均值
-
将均值设为新的中心点
-
再反复以上过程
可见该算法为迭代算法
算法实现
- 拿到数据集,先画图观察
def Draw(data):
    """Scatter-plot the two feature columns of data['X'].

    data: dict-like (e.g. a scipy.io.loadmat result) whose 'X' entry is
    an (m, 2) array; the two columns are plotted as X1 vs X2.
    """
    data1 = pd.DataFrame(data.get('X'), columns=['X1', 'X2'])
    # Fix: pass x/y as keywords — positional use was deprecated in
    # seaborn 0.11 and removed in 0.12.
    sb.lmplot(x='X1', y='X2', data=data1, fit_reg=False)
    plt.show()
- 随机设定初始的中心坐标
def init_center(X, k):
    """Pick k random rows of X as the initial cluster centroids.

    X : (m, n) data matrix.
    k : number of clusters.
    Returns a (k, n) float array of starting centroids (duplicate rows
    are possible because indices are drawn with replacement).
    """
    m, n = X.shape
    # Bug fix: the buffer was created as `centroids` but filled and
    # returned as `center`, which raised NameError; use one name.
    center = np.zeros((k, n))
    # Draw k row indices at random from the data set.
    idx = np.random.randint(0, m, k)
    for i in range(k):
        center[i, :] = X[idx[i], :]
    return center
- 将数据集与距离其最近的中心点归为一类
def find_closest_center(X, center):
    """Assign every sample in X to its nearest centroid.

    X      : (m, n) data matrix.
    center : (k, n) matrix of centroid coordinates.
    Returns a length-m float array; entry i is the index of the centroid
    closest to X[i] in squared Euclidean distance.
    """
    m = X.shape[0]
    k = center.shape[0]
    cat = np.zeros(m)
    for i in range(m):
        # Bug fix: the old finite sentinel (100000) silently left any
        # point whose squared distance to every centroid exceeded it
        # assigned to cluster 0; use +inf so the first comparison wins.
        min_dis = float('inf')
        for j in range(k):
            dis = np.sum((X[i, :] - center[j, :]) ** 2)
            if dis < min_dis:
                min_dis = dis
                cat[i] = j
    return cat
- 计算每一类的均值
def compute_center(X, idx, k):
    """Recompute each centroid as the mean of its assigned samples.

    X   : (m, n) data matrix.
    idx : length-m array of cluster assignments (values in 0..k-1).
    k   : number of clusters.
    Returns a (k, n) array of updated centroids.  A cluster with no
    members keeps a zero row — the original divided by zero and
    produced NaNs for empty clusters.
    """
    m, n = X.shape
    centers = np.zeros((k, n))
    for i in range(k):
        # Boolean-mask selection of this cluster's members.
        members = X[idx == i]
        if members.shape[0] > 0:  # guard against an empty cluster
            centers[i, :] = members.mean(axis=0)
    return centers
- 更新新的中心点,以上过程迭代进行n次
def run_k_means(X, initial_center, max_iters):
    """Run max_iters rounds of the k-means assignment/update loop.

    Returns (idx, center): the final per-sample cluster labels and the
    final (k, n) centroid matrix.
    """
    k = initial_center.shape[0]
    center = initial_center
    for _ in range(max_iters):
        # E-step: label every point by its nearest centroid.
        idx = find_closest_center(X, center)
        # M-step: move each centroid to the mean of its members.
        center = compute_center(X, idx, k)
    return idx, center
完整代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb
from scipy.io import loadmat
def Draw(data):
    """Scatter-plot the two feature columns of data['X'].

    data: dict-like (e.g. a scipy.io.loadmat result) whose 'X' entry is
    an (m, 2) array; the two columns are plotted as X1 vs X2.
    """
    data1 = pd.DataFrame(data.get('X'), columns=['X1', 'X2'])
    # Fix: pass x/y as keywords — positional use was deprecated in
    # seaborn 0.11 and removed in 0.12.
    sb.lmplot(x='X1', y='X2', data=data1, fit_reg=False)
    plt.show()
def find_closest_center(X, center):
    """Assign every sample in X to its nearest centroid.

    X      : (m, n) data matrix.
    center : (k, n) matrix of centroid coordinates.
    Returns a length-m float array; entry i is the index of the centroid
    closest to X[i] in squared Euclidean distance.
    """
    m = X.shape[0]
    k = center.shape[0]
    cat = np.zeros(m)
    for i in range(m):
        # Bug fix: the old finite sentinel (100000) silently left any
        # point whose squared distance to every centroid exceeded it
        # assigned to cluster 0; use +inf so the first comparison wins.
        min_dis = float('inf')
        for j in range(k):
            dis = np.sum((X[i, :] - center[j, :]) ** 2)
            if dis < min_dis:
                min_dis = dis
                cat[i] = j
    return cat
def compute_center(X, idx, k):
    """Recompute each centroid as the mean of its assigned samples.

    X   : (m, n) data matrix.
    idx : length-m array of cluster assignments (values in 0..k-1).
    k   : number of clusters.
    Returns a (k, n) array of updated centroids.  A cluster with no
    members keeps a zero row — the original divided by zero and
    produced NaNs for empty clusters.
    """
    m, n = X.shape
    centers = np.zeros((k, n))
    for i in range(k):
        # Boolean-mask selection of this cluster's members.
        members = X[idx == i]
        if members.shape[0] > 0:  # guard against an empty cluster
            centers[i, :] = members.mean(axis=0)
    return centers
def run_k_means(X, initial_center, max_iters):
    """Run max_iters rounds of the k-means assignment/update loop.

    Returns (idx, center): the final per-sample cluster labels and the
    final (k, n) centroid matrix.
    """
    k = initial_center.shape[0]
    center = initial_center
    for _ in range(max_iters):
        # E-step: label every point by its nearest centroid.
        idx = find_closest_center(X, center)
        # M-step: move each centroid to the mean of its members.
        center = compute_center(X, idx, k)
    return idx, center
def init_center(X, k):
    """Choose k random rows of X (with replacement) as initial centroids.

    Returns a (k, n) float64 array.
    """
    m = X.shape[0]
    # One vectorized draw of k row indices, then a fancy-indexed copy;
    # cast to float64 to match the dtype a zeros-based buffer would have.
    chosen = np.random.randint(0, m, k)
    return X[chosen].astype(np.float64)
# --- Demo: run k-means on the course data set -----------------------
data = loadmat('./data7/ex7data2.mat')
Draw(data)

X = data['X']
# Fixed starting centroids (same values as the course exercise).
center = np.array([[3, 3], [6, 2], [8, 5]])

# Sanity-check the two building blocks on the initial centroids.
idx = find_closest_center(X, center)
print(idx[0:3])
dis = compute_center(X, idx, 3)

# Full run: 10 assignment/update iterations.
idx, center = run_k_means(X, center, 10)

# Split the samples by final label for plotting.
cat1 = X[np.where(idx == 0)[0], :]
cat2 = X[np.where(idx == 1)[0], :]
cat3 = X[np.where(idx == 2)[0], :]

fig, ax = plt.subplots(figsize=(12, 8))
# Bug fix: without label= arguments, ax.legend() has no labeled artists
# to show and emits a "No artists with labels found" warning.
ax.scatter(cat1[:, 0], cat1[:, 1], s=30, color='r', label='cluster 1')
ax.scatter(cat2[:, 0], cat2[:, 1], s=30, color='g', label='cluster 2')
ax.scatter(cat3[:, 0], cat3[:, 1], s=30, color='b', label='cluster 3')
ax.legend()
plt.show()
print(dis)