目录
1、无监督算法简介
无监督学习,即对于没有类别标签的样本集,模型可以根据样本的特性,把样本划分为若干个子集(类),达到同类样本特性相差小、异类样本特性相差大的效果,实现自主分类,无师自通。
2、K-means算法思想
K-means又称K均值聚类法(距离平方和最小聚类法),其思想与KNN算法相似,但不同的是K-means的聚类中心点会变化,直到算法达到最优。即先从样本集中随机选取 (或人为给定)k个样本作为聚类中心点(可以简单理解为KNN中的训练数据,虽然量少),然后计算每个 聚类中心点到所有样本的距离,根据距离将每个样本与其距离最近的聚类中心划分为同一类,然后在每个类别中取平均值(numpy中mean()函数用于计算数组中元素的算术平均数)作为新的聚类中心点,实现聚类中心点的更新,重复上述步骤,直到聚类中心点不再发生变化,即算法收敛。
3、K-means算法实现步骤
- 选取k个中心点(随机选取或人为给定)
- 根据中心点将样本数据分类
- 根据第二步的分类结果,在每个类中重新确定新的中心点
- 判断是否继续迭代,即如果前一次中心点与本次中心点相同,则计算结束,否则重复步骤2
4、K-means示意图
5、代码实现
import matplotlib.pyplot as plt
import numpy as np
#导入数据
test_data=np.loadtxt(r'C:\Users\testSet.txt')
k=int(input('请输入类别数K:'))
num_data=np.shape(test_data)[0] #获取数据个数
data_label=np.zeros(num_data)#用于存储数据标签
#确定初始中心点
if k==3:
num_center=np.array([[-4.822,4.607],[-0.7188,-2.493],[4.377,4.864]])#根据题意,k=3时,用给定的中心点
num_before=num_center
else:
num_center=np.mat(np.zeros((k,2)))#其他情况随机给定k个对应的中心点
for i in range(2):
min_xy=min(test_data[:,i])#分别获取x,y坐标的最小值
range_i=float(max(test_data[:,i]-min_xy))#分别获取x,y坐标的对应的最大差值
num_center[:,i]=np.mat(min_xy+range_i*np.random.rand(k,1))#基于最小值,保证数据合理,确定中心点的x,y坐标
num_before=num_center
continue_calculate=True #用于是否继续计算,即是否继续迭代
while continue_calculate:
continue_calculate=False #改变迭代标记
for i in range(num_data):
min_dist=np.inf #给最小距离一个较大的初始值,此处取无穷大
min_index=-1 #用于后续标记的类别
for j in range(k):
d=np.sqrt(np.sum(np.power(num_center[j,:]-test_data[i,:],2))) #求出个点与中心点的欧式距离
if d<min_dist:
min_dist=d
min_index=j #用k个值分别标记的数据的类别
if data_label[i]!=min_index:#判断是否进行迭代
continue_calculate=True
data_label[i]=min_index #用数据打好标签
#如果if语句被执行,则表明需要继续迭代,则需要更新中心点
#更新中心点
for i in range(k):
num=test_data[np.nonzero(data_label[:]==i)[0]]#获取同一类别的x,y坐标值
num_center[i,:]=np.mean(num,axis=0)#用mean函数求出同一类别的x,y坐标值对应的平均值
print(num_center)#输出中心点
#绘图
colors=['red','green','blue','black','orange','purple','yellow','pink']
for i,col in zip(range(k),colors):#
xy=np.empty([0,2])
for j,m in enumerate(data_label):
if m==i:
xy=np.append(xy,[[test_data[j,0],test_data[j,1]]],axis=0)
plt.scatter(xy[:,0],xy[:,1],c=col,label=('data in group',i))
plt.scatter(num_before[:,0].tolist(),num_before[:,1].tolist(),c='pink',marker='s',label='before center')
plt.scatter(num_center[:,0].tolist(),num_center[:,1].tolist(),c='y',marker='^',label='after center')
plt.legend()
plt.show()
6、结果展示
当k=5时,第一次迭代结果
当k=5 时,最终结果
本文仅供学习交流