目录
本篇理论性不多,主要是部分总结及实战内容。
一、EM算法的步骤
EM算法(英文叫做Expectation Maximization,最大期望算法)三个主要的步骤:
- 初始化参数
- 观察预期
- 重新估计
二、EM算法的工作原理
EM算法一般用于聚类,也就是无监督模型里面,因为无监督学习没有标签,EM算法可以先给无监督学习估计一个隐状态(即标签),有了标签,算法模型就可以转换成有监督学习,这时就可以用极大似然估计法求解出模型最优参数。其中估计隐状态流程应为EM算法的E步,后面用极大似然估计为M步。
下面介绍一下两种不同类型的聚类算法:
- 硬聚类算法:如K-Means ,是通过距离来区分样本之间的差别的,且每个样本在计算的时候只能属于一个分类。
- 软聚类算法:如EM 聚类,它在求解的过程中,实际上每个样本都有一定的概率和每个聚类相关。
EM 算法相当于一个框架,你可以采用不同的模型来进行聚类,比如 GMM(高斯混合模型),或者 HMM(隐马尔科夫模型)来进行聚类。
- GMM 是通过概率密度来进行聚类,聚成的类符合高斯分布(正态分布)。
- HMM 用到了马尔科夫过程,在这个过程中,我们通过状态转移矩阵来计算状态转移的概率。
三、在sklearn中创建GMM模型
本案例采用GMM高斯混合模型。因此将介绍下如何在sklearn中创建GMM聚类。
gmm = GaussianMixture(n_components=1, covariance_type='full', max_iter=100)
看一下这几个参数:
1. n_components:即高斯混合模型的个数,也就是我们要聚类的个数,默认值为 1。如果你不指定 n_components,最终的聚类结果都会为同一个值。
2. covariance_type:代表协方差类型。一个高斯混合模型的分布是由均值向量和协方差矩阵决定的,所以协方差的类型也代表了不同的高斯混合模型的特征。协方差类型有 4 种取值:
- covariance_type=full,代表完全协方差,也就是元素都不为 0;
- covariance_type=tied,代表相同的完全协方差;
- covariance_type=diag,代表对角协方差,也就是对角不为 0,其余为 0;
- covariance_type=spherical,代表球面协方差,非对角为 0,对角完全相同,呈现球面的特性。
3. max_iter:代表最大迭代次数,EM 算法是由 E 步和 M 步迭代求得最终的模型参数,这里可以指定最大迭代次数,默认值为 100。
创建完GMM聚类器之后,可以传入数据让它进行迭代拟合。我们使用 fit 函数,传入样本特征矩阵,模型会自动生成聚类器,然后使用 prediction=gmm.predict(data) 来对数据进行聚类,传入你想进行聚类的数据,可以得到聚类结果 prediction。
四、工作流程
我们使用王者荣耀英雄数据集来进行聚类,数据包括 69 名英雄的 23 个特征属性。这些属性分别是,英雄,最大生命,生命成长,初始生命,最大法力,法力成长,初始法力,最高物攻,物攻成长,初始物攻,最大物防,物防成长,初始物防,最大每5秒回血,每5秒回血成长,初始每5秒回血,最大每5秒回蓝,每5秒回蓝成长,初始每5秒回蓝,最大攻速,攻击范围,主要定位,次要定位。
王者荣耀英雄数据集链接:https://github.com/cystanford/EM_data
先划分一下流程:
整个训练过程基本上都会包括三个阶段:
-
首先加载数据集;
-
在准备阶段,我们需要对数据进行探索,包括采用数据可视化技术,让我们对英雄属性以及这些属性之间的关系理解更加深刻,然后对数据质量进行评估,是否进行数据清洗,最后进行特征选择方便后续的聚类算法;
-
聚类阶段:选择适合的聚类模型,这里我们采用 GMM 高斯混合模型进行聚类,并输出聚类结果,对结果进行分析。
五、实战环节
1. 导包
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.mixture import GaussianMixture
2. 加载数据
# 数据加载,避免中文乱码问题
data_original = pd.read_csv('dataset/heros.csv', encoding = 'gb18030')
3. 数据可视化分析
data_original.head().append(data_original.tail()) # 显示前5行和后5行数据
英雄 | 最大生命 | 生命成长 | 初始生命 | 最大法力 | 法力成长 | 初始法力 | 最高物攻 | 物攻成长 | 初始物攻 | ... | 最大每5秒回血 | 每5秒回血成长 | 初始每5秒回血 | 最大每5秒回蓝 | 每5秒回蓝成长 | 初始每5秒回蓝 | 最大攻速 | 攻击范围 | 主要定位 | 次要定位 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 夏侯惇 | 7350 | 288.8 | 3307 | 1746 | 94 | 430 | 321 | 11.570 | 159 | ... | 98 | 3.357 | 51 | 37 | 1.571 | 15 | 28.00% | 近战 | 坦克 | 战士 |
1 | 钟无艳 | 7000 | 275.0 | 3150 | 1760 | 95 | 430 | 318 | 11.000 | 164 | ... | 92 | 3.143 | 48 | 37 | 1.571 | 15 | 14.00% | 近战 | 战士 | 坦克 |
2 | 张飞 | 8341 | 329.4 | 3450 | 100 | 0 | 100 | 301 | 10.570 | 153 | ... | 115 | 4.143 | 57 | 5 | 0.000 | 5 | 14.00% | 近战 | 坦克 | 辅助 |
3 | 牛魔 | 8476 | 352.8 | 3537 | 1926 | 104 | 470 | 273 | 8.357 | 156 | ... | 117 | 4.214 | 58 | 42 | 1.786 | 17 | 14.00% | 近战 | 坦克 | 辅助 |
4 | 吕布 | 7344 | 270.0 | 3564 | 0 | 0 | 0 | 343 | 12.360 | 170 | ... | 97 | 3.071 | 54 | 0 | 0.000 | 0 | 14.00% | 近战 | 战士 | 坦克 |
64 | 阿轲 | 5968 | 192.8 | 3269 | 0 | 0 | 0 | 427 | 17.860 | 177 | ... | 81 | 2.214 | 50 | 0 | 0.000 | 0 | 28.00% | 近战 | 刺客 | NaN |
65 | 娜可露露 | 6205 | 211.9 | 3239 | 1808 | 97 | 450 | 385 | 15.140 | 173 | ... | 79 | 2.286 | 47 | 38 | 1.571 | 16 | 14.00% | 近战 | 刺客 | NaN |
66 | 兰陵王 | 6232 | 210.0 | 3292 | 1822 | 98 | 450 | 388 | 15.500 | 171 | ... | 99 | 3.357 | 52 | 46 | 1.929 | 19 | 14.00% | 近战 | 刺客 | NaN |
67 | 铠 | 6700 | 237.5 | 3375 | 1784 | 96 | 440 | 328 | 10.860 | 176 | ... | 81 | 2.643 | 44 | 38 | 1.571 | 16 | 28.00% | 近战 | 战士 | 坦克 |
68 | 百里守约 | 5611 | 185.1 | 3019 | 1784 | 96 | 440 | 410 | 15.860 | 188 | ... | 68 | 2.071 | 39 | 38 | 1.571 | 16 | 28.00% | 远程 | 射手 | 刺客 |
# 对英雄属性之间的关系进行可视化分析
# 设置plt正确显示中文
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
# 用热力图呈现特征之间的相关性
corr = data.corr()
plt.figure(figsize=(14,14))
# annot=True显示每个方格的数据
sns.heatmap(corr, annot=True)
plt.show()
我们将 18 个英雄属性之间的关系用热力图呈现了出来,中间的数字代表两个属性之间的关系系数,最大值为 1,代表完全正相关,关系系数越大代表相关性越大。
从图中你能看出来“最大生命”“生命成长”和“初始生命”这三个属性的相关性大,我们只需要保留一个属性即可。同理我们也可以对其他相关性大的属性进行筛选,保留一个。 这既是对原有属性进行降维。
4. 特征工程
# 相关性大的属性保留一个,因此可以对属性进行降维
features_remain = [u'最大生命', u'初始生命', u'最大法力', u'最高物攻', u'初始物攻', u'最大物防', u'初始物防',
u'最大每5秒回血', u'最大每5秒回蓝', u'初始每5秒回蓝', u'最大攻速', u'攻击范围']
data = data_original[features_remain]
data
最大生命 | 初始生命 | 最大法力 | 最高物攻 | 初始物攻 | 最大物防 | 初始物防 | 最大每5秒回血 | 最大每5秒回蓝 | 初始每5秒回蓝 | 最大攻速 | 攻击范围 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7350 | 3307 | 1746 | 321 | 159 | 397 | 101 | 98 | 37 | 15 | 28.00% | 近战 |
1 | 7000 | 3150 | 1760 | 318 | 164 | 409 | 100 | 92 | 37 | 15 | 14.00% | 近战 |
2 | 8341 | 3450 | 100 | 301 | 153 | 504 | 125 | 115 | 5 | 5 | 14.00% | 近战 |
3 | 8476 | 3537 | 1926 | 273 | 156 | 394 | 109 | 117 | 42 | 17 | 14.00% | 近战 |
4 | 7344 | 3564 | 0 | 343 | 170 | 390 | 99 | 97 | 0 | 0 | 14.00% | 近战 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
64 | 5968 | 3269 | 0 | 427 | 177 | 349 | 89 | 81 | 0 | 0 | 28.00% | 近战 |
65 | 6205 | 3239 | 1808 | 385 | 173 | 359 | 86 | 79 | 38 | 16 | 14.00% | 近战 |
66 | 6232 | 3292 | 1822 | 388 | 171 | 342 | 85 | 99 | 46 | 19 | 14.00% | 近战 |
67 | 6700 | 3375 | 1784 | 328 | 176 | 388 | 107 | 81 | 38 | 16 | 28.00% | 近战 |
68 | 5611 | 3019 | 1784 | 410 | 188 | 329 | 94 | 68 | 38 | 16 | 28.00% | 远程 |
5. 数据规范化
我们能看到“最大攻速”这个属性值是百分数,不适合做矩阵运算,因此我们需要将百分数转化为小数。我们也看到“攻击范围”这个字段的取值为远程或者近战,也不适合矩阵运算,我们将取值做个映射,用 1 代表远程,0 代表近战。然后采用 Z-Score 规范化,对特征矩阵进行规范化。
data[u'最大攻速'] = data[u'最大攻速'].apply(lambda x: float(x.strip('%'))/100)
data[u'攻击范围'] = data[u'攻击范围'].map({'远程':1,'近战':0})
# 采用Z-Score规范化数据,保证每个特征维度的数据均值为0,方差为1
ss = StandardScaler()
data = ss.fit_transform(data)
6. 建模并产生结果,写入文件
# 构造GMM聚类
gmm = GaussianMixture(n_components=30, covariance_type='full')
gmm.fit(data)
# 训练数据
prediction = gmm.predict(data)
print(prediction)
# 将分组结果输出到CSV文件中
data_original.insert(0, '分组', prediction)
data_original.to_csv('./EM_data/heros_out.csv', index=False, sep=',')
[ 2 13 6 8 26 3 0 6 21 13 7 13 21 20 17 7 27 21 26 5 9 5 5 5
5 5 5 1 4 23 20 4 16 4 23 4 4 19 12 16 16 4 4 4 29 16 13 12
13 29 24 14 10 11 11 2 25 13 22 26 25 10 15 2 18 14 14 28 1]
我们采用了 GMM 高斯混合模型,并将结果输出到 CSV 文件中。聚类个数为 30。
7. 显示聚类后的结果
data_group = pd.read_csv('./EM_data/heros_out.csv')
data_group.sort_values('分组')
分组 | 英雄 | 最大生命 | 生命成长 | 初始生命 | 最大法力 | 法力成长 | 初始法力 | 最高物攻 | 物攻成长 | ... | 最大每5秒回血 | 每5秒回血成长 | 初始每5秒回血 | 最大每5秒回蓝 | 每5秒回蓝成长 | 初始每5秒回蓝 | 最大攻速 | 攻击范围 | 主要定位 | 次要定位 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
6 | 0 | 芈月 | 6164 | 281.5 | 3105 | 100 | 0 | 100 | 289 | 9.786 | ... | 77 | 2.357 | 44 | 0 | 0.000 | 0 | 0.00% | 远程 | 法师 | 坦克 |
68 | 1 | 百里守约 | 5611 | 185.1 | 3019 | 1784 | 96 | 440 | 410 | 15.860 | ... | 68 | 2.071 | 39 | 38 | 1.571 | 16 | 28.00% | 远程 | 射手 | 刺客 |
27 | 1 | 成吉思汗 | 5799 | 198.0 | 3027 | 1742 | 93 | 440 | 394 | 15.000 | ... | 66 | 2.071 | 37 | 36 | 1.500 | 15 | 42.00% | 远程 | 射手 | NaN |
63 | 2 | 哪吒 | 7268 | 270.4 | 3483 | 1808 | 97 | 450 | 320 | 11.500 | ... | 98 | 3.214 | 53 | 38 | 1.571 | 16 | 28.00% | 近战 | 战士 | NaN |
55 | 2 | 杨戬 | 7420 | 291.5 | 3339 | 1694 | 91 | 420 | 325 | 11.360 | ... | 98 | 3.357 | 51 | 36 | 1.500 | 15 | 28.00% | 近战 | 战士 | NaN |
0 | 2 | 夏侯惇 | 7350 | 288.8 | 3307 | 1746 | 94 | 430 | 321 | 11.570 | ... | 98 | 3.357 | 51 | 37 | 1.571 | 15 | 28.00% | 近战 | 坦克 | 战士 |
5 | 3 | 亚瑟 | 8050 | 316.3 | 3622 | 0 | 0 | 0 | 346 | 13.000 | ... | 106 | 3.643 | 55 | 0 | 0.000 | 0 | 14.00% | 近战 | 战士 | 坦克 |
31 | 4 | 甄姬 | 5584 | 181.6 | 3041 | 2002 | 108 | 490 | 296 | 9.357 | ... | 71 | 2.000 | 43 | 44 | 1.857 | 18 | 14.00% | 远程 | 法师 | NaN |
33 | 4 | 干将莫邪 | 5583 | 171.0 | 3189 | 1946 | 104 | 490 | 292 | 9.500 | ... | 71 | 1.857 | 45 | 41 | 1.714 | 17 | 14.00% | 远程 | 法师 | NaN |
41 | 4 | 小乔 | 5916 | 202.0 | 3088 | 1988 | 107 | 490 | 263 | 7.857 | ... | 75 | 2.214 | 44 | 43 | 1.786 | 18 | 14.00% | 远程 | 法师 | NaN |
第一列代表的是分组(簇),我们能看到百里守约和成吉思汗分到了一组,哪吒、杨戬和夏侯惇是一组,亚瑟自己是一组,甄姬、干将莫邪和小乔是一组。
聚类的特点是相同类别之间的属性值相近,不同类别的属性值差异大。
8. 聚类结果的评估
聚类和分类不一样,聚类是无监督的学习方式,也就是我们没有实际的结果可以进行比对,所以聚类的结果评估不像分类准确率一样直观,那么有没有聚类结果的评估方式呢?这里我们可以采用 Calinski-Harabaz 指标,代码如下:
from sklearn.metrics import calinski_harabaz_score
print(calinski_harabaz_score(data, prediction))
20.273576816244606
指标分数越高,代表聚类效果越好,也就是相同类中的差异性小,不同类之间的差异性大。当然具体聚类的结果含义,我们需要人工来分析,也就是当这些数据被分成不同的类别之后,具体每个类表代表的含义。