一、K-means算法原理
聚类的概念:一种无监督的学习,事先不知道类别,自动将相似的对象归到同一个簇中。
K-Means算法是一种聚类分析(cluster analysis)的算法,其主要是来计算数据聚集的算法,主要通过不断地取离种子点最近均值的算法。
K-Means算法主要解决的问题如下图所示。我们可以看到,在图的左边有一些点,我们用肉眼可以看出来有四个点群,但是我们怎么通过计算机程序找出这几个点群来呢?于是就出现了我们的K-Means算法。
图解:
从上图中,我们可以看到,A,B,C,D,E是五个在图中点。而灰色的点是我们的种子点,也就是我们用来找点群的点。有两个种子点,所以K=2。
然后,K-Means的算法如下:
-
随机在图中取K(这里K=2)个种子点。
-
然后对图中的所有点求到这K个种子点的距离,假如点Pi离种子点Si最近,那么Pi属于Si点群。(上图中,我们可以看到A,B属于上面的种子点,C,D,E属于下面中部的种子点)
-
接下来,我们要移动种子点到属于他的“点群"的中心。(见图上的第三步)
-
然后重复第(2)和第(3)步,直到,种子点没有移动(我们可以看到图中的第四步上面的种子点聚合了A,B,C,下面的种子点聚合了D,E)。
总结:K-Means算法步骤:
- 从数据中选择k个对象作为初始聚类中心;
- 计算每个聚类对象到聚类中心的距离来划分;
- 再次计算每个聚类中心。
- 计算标准测度函数,直到达到最大迭代次数,则停止,否则,继续操作。
- 确定最优的聚类中心。
K-means聚类方法总结
优点:
- 解决聚类问题的经典算法,简单
- 当处理大数据集时,该算法保持可伸缩性和高效率(与神经网络比)
- 当簇近似正态分布时,效果较好
缺点:
- 在簇的平均值可被定义的情况下才能使用,可能不适用于某些应用
- 必须事先给出k(要生成簇的数目),而且对初值敏感,即对于不同的初值,可能会导致不同结果
- 不适合非凸形状的簇或者大小差别很大的簇
- 对噪声和孤立点敏感
聚类算法应用举例
- 文档分类器
根据标签、主题和文档内容将文档分为多个不同的类别。这是一个非常标准且经典的K-means算法分类问题。首先,需要对文档进行初始化处理,将每个文档都用矢量来表示,并使用术语频率来识别常用术语进行文档分类,这一步很有必要。然后对文档向量进行聚类,识别文档组中的相似性。 - 客户分类
聚类能过帮助营销人员改善他们的客户群(在其目标区域内工作),并根据客户的购买历史、兴趣或活动监控来对客户类别做进一步细分。这是关于电信运营商如何将预付费客户分为充值模式、发送短信和浏览网站几个类别的白皮书。对客户进行分类有助于公司针对特定客户群制定特定的广告。 - 保险欺诈检测
机器学习在欺诈检测中也扮演着一个至关重要的角色,在汽车、医疗保险和保险欺诈检测领域中广泛应用。利用以往欺诈性索赔的历史数据,根据它和欺诈性模式聚类的相似性来识别新的索赔。由于保险欺诈可能会对公司造成数百万美元的损失,因此欺诈检测对公司来说至关重要。这是汽车保险中使用聚类来检测欺诈的白皮书。 - 乘车数据分析
面向大众公开的uber乘车信息的数据集,为我们提供了大量关于交通、运输时间、高峰乘车地点等有价值的数据集。分析这些数据不仅对uber大有好处,而且有助于我们对城市的交通模式进行深入的了解,来帮助我们做城市未来规划。
二、实战
重要参数:
- n_clusters:聚类的个数
重要属性:
- cluster_centers_:[n_clusters, n_features]的数组,表示聚类中心点的坐标
- labels_:每个样本点的标签
2.1 聚类实例
(1)聚类的基本使用
import numpy as np
import pandas as pd
import pyecharts.options as opts
import matplotlib.pyplot as plt
import seaborn as sns
from pyecharts.charts import Scatter
# 手动生成随机点做聚类
from sklearn.datasets import make_blobs
# 制作一个假的数据集
X, y = make_blobs(
n_samples=150,
n_features=2,
centers=3,
cluster_std=1.5,
random_state=2
)
# 下面都是在用pyecharts画画
def add_data(pic, X_data, y_data, symbol='circle', symbol_size=10):
X_data = pd.DataFrame(X_data).copy()
y_data = pd.Series(y_data).copy()
for i in y_data.drop_duplicates():
pic.add_xaxis(xaxis_data=X_data.loc[y_data == i].iloc[0:, 0].tolist())
pic.add_yaxis(
series_name=i,
y_axis=X_data.loc[y_data == i].iloc[0:, 1].tolist(),
symbol_size=symbol_size,
symbol=symbol,
label_opts=opts.LabelOpts(is_show=False)
)
import pyecharts.options as opts
from pyecharts.charts import Scatter, Grid
from pyecharts.globals import ThemeType
m_scatter = Scatter(init_opts=opts.InitOpts(width="400px", height="400px", theme=ThemeType.LIGHT))
add_data(m_scatter, X, y)
m_scatter.set_series_opts()
m_scatter.set_global_opts(
xaxis_opts=opts.AxisOpts(
type_="value",
name="x1",
splitline_opts=opts.SplitLineOpts(is_show=True)
),
yaxis_opts=opts.AxisOpts(
type_="value",
name="x2",
axistick_opts=opts.AxisTickOpts(is_show=True),
splitline_opts=opts.SplitLineOpts(is_show=True)
),
tooltip_opts=opts.TooltipO