机器学习之线性判别分析LDA

线性判别分析(Linear Discriminant Analysis, 简称LDA)是一种经典的线性学习方法,LDA算法既可以用来监督式的降维,也可以用来分类。
注意:LDA主题模型是指文本建模的文档主题生成模型,LDA是Latent Dirichlet Allocation的简称。

1模型优化

LDA思想:给定训练样例集,设法将样例投影到一条直线上,使得同类样例的投影点尽可能接近、异类样例的投影点尽可能远离。
lda示意图
欲使同类样例的投影点尽可能接近,可以让同类样例投影点的协方差尽可能小,即 w T ∑ 0 w + w T ∑ 1 w w^T\sum\nolimits_{0}w+w^T\sum\nolimits_{1}w wT0w+wT1w尽可能小;
而欲使异类样例的投影点尽可能远离,可以让类中心之间的距离尽可能大,即 ∥ w T μ 0 − w T μ 1 ∥ 2 2 \left \| w^T\mu_{0}-w^T\mu_{1}\right \|_{2}^{2} wTμ0wTμ122尽可能大。
也可从贝叶斯决策理论的角度来阐释,并可证明,当两类数据同先验、满足高斯分布且协方差相等时,LDA可达到最优分类。

2Sklearn代码实现

在Sklearn库中逻辑回归模型使用 LinearDiscriminantAnalysis类,其求解器(solver)可使用的优化算法,包括svd(奇异值分解) 、lsqr(最小二乘)和 eigen(特征分解)。

  • 默认的 solver 是 ‘svd’。
    它可以进行classification (分类) 以及 transform (转换),而且它不会依赖于协方差矩阵的计算(结果)。这在特征数量特别大的时候十分具有优势。然而,’svd’ solver 无法与 shrinkage (收缩)同时使用。
  • lsqr solver 则是一个高效的算法,它仅用于分类使用。它支持 shrinkage (收缩)。
  • eigen(特征) solver 是基于 class scatter (类散度)与 class scatter ratio (类内离散率)之间的优化。
    它可以被用于 classification (分类)以及 transform (转换),此外它还同时支持收缩。然而,该解决方案需要计算协方差矩阵,因此它可能不适用于具有大量特征的情况。

Shrinkage(收缩)是一种在训练样本数量相比特征而言很小的情况下可以提升的协方差矩阵预测(准确性)的工具。

  • shrinkage parameter (收缩参数)的值 可以设置为‘auto’ ,同样也可以手动被设置为 0-1 之间。
    特别地,0 值对应着没有收缩(这意味着经验协方差矩阵将会被使用), 而 1 值则对应着完全使用收缩(意味着方差的对角矩阵将被当作协方差矩阵的估计)。
    n_components(希望降到的维数):即我们进行LDA降维时降到的维数。在降维时需要输入这个参数。注意只能为[1,类别数-1)范围之间的整数。如果我们不是用于降维,则这个值可以用默认的None。

示例:鸢尾花数据集分类任务
将前文【机器学习之逻辑回归(对率回归)】中创建模型的语句更改为如下即可:

...
# 创建模型
model = LinearDiscriminantAnalysis()
...

示例:鸢尾花数据集降维

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# 加载数据
iris = datasets.load_iris()
# 使用样本的所有特征
x = iris.data  # print(iris.feature_names)
y = iris.target
label_dict = iris.target_names

# 创建模型
model = LinearDiscriminantAnalysis(n_components=2)
# 训练且降维
x_reduce = model.fit_transform(x, y)

# 样本前3个特征绘图,立体展示数据
fig = plt.figure()
ax3d = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax3d.scatter(x[:, 0], x[:, 1], x[:, 2], c=y, cmap='brg')
plt.show()

# 降维数据的散点图
for label, marker, color in zip(range(0, 3), ('*', 's', 'o'), ('blue', 'red', 'green')):
plt.scatter(x=x_reduce[y == label][:, 0],
	    y=x_reduce[y == label][:, 1],
	    marker=marker,
	    color=color,
	    alpha=0.5,
	    label=label_dict[label])
plt.title('Iris Reduced LDA')
plt.xlabel('x', fontsize=14)
plt.ylabel('y', fontsize=14)
plt.legend(loc='upper right', fancybox=True)
plt.tick_params(labelsize=10)
plt.show()

运行效果图,如下:
3D
reduced

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值