机器学习算法系列(十二)-二次判别分析算法(Quadratic Discriminant Analysis Algorithm)

阅读本文需要的背景知识点:线性判别分析、一丢丢编程知识

一、引言

  前面两节介绍了线性判别分析在不同角度下的实现方式,一种是根据费舍尔“类内小、类间大”的角度,另一种则是从概率分布的角度。本节来介绍另一种判别分析——二次判别分析算法1(Quadratic Discriminant Analysis Algorithm / QDA)

二、模型介绍

  同线性判别分析一样,从概率分布的角度来得到二次判别分析,区别在于线性判别分析假设每一种分类的协方差矩阵相同,而二次判别分析中每一种分类的协方差矩阵不同。

(1)同线性判别分析一样,我们的目的就是求在输入为 x 的情况下分类为 k 的概率最大的分类,所以我们可以写出假设函数如下图(1)式
(2)对其概率取对数,不影响函数的最后结果
(3)带入上面的 P ( k ∣ x ) P(k|x) P(kx) 的表达式,由于 P ( x ) P(x) P(x) 对最后结果也没有影响,也可以直接去掉
(4)带入多元正态分布的概率密度函数表达式,注意这里与线性判别分析的不同,协方差矩阵在每一种类型下是不同的
(5)将(4)式中的对数化简得到
(6)这时就不能和线性判别分析一样去掉第二项了,而是要保留其中协方差矩阵行列式的部分,得到最后的结果
h ( x ) = argmax ⁡ k P ( k ∣ x ) ( 1 ) = argmax ⁡ k ln ⁡ P ( k ∣ x ) ( 2 ) = argmax ⁡ k ln ⁡ f k ( x ) + ln ⁡ P ( k ) ( 3 ) = argmax ⁡ k ln ⁡ ( e − ( x − μ k ) T Σ k − 1 ( x − μ k ) 2 ∣ Σ k ∣ 1 2 ( 2 π ) p 2 ) + ln ⁡ P ( k ) ( 4 ) = argmax ⁡ k − 1 2 ( x − μ k ) T Σ k − 1 ( x − μ k ) − ln ⁡ ( ∣ Σ k ∣ 1 2 ( 2 π ) p 2 ) + ln ⁡ P ( k ) ( 5 ) = argmax ⁡ k − 1 2 ( x − μ k ) T Σ k − 1 ( x − μ k ) − 1 2 ln ⁡ ( ∣ Σ k ∣ ) + ln ⁡ P ( k ) ( 6 ) \begin{aligned} h(x) &=\underset{k}{\operatorname{argmax}} P(k \mid x) & (1)\\ &=\underset{k}{\operatorname{argmax}} \ln P(k \mid x) & (2)\\ &=\underset{k}{\operatorname{argmax}} \ln f_{k}(x)+\ln P(k) & (3) \\ &=\underset{k}{\operatorname{argmax}} \ln \left(\frac{e^{-\frac{\left(x-\mu_{k}\right)^{T}{\Sigma_{k}^{-1}\left(x-\mu_{k}\right)}}{2}}}{\left|\Sigma_{k}\right|^{\frac{1}{2}}(2 \pi)^{\frac{p}{2}}}\right)+\ln P(k) & (4) \\ &=\underset{k}{\operatorname{argmax}} -\frac{1}{2}\left(x-\mu_{k}\right)^{T} \Sigma_{k}^{-1}\left(x-\mu_{k}\right)-\ln \left(\left|\Sigma_{k}\right|^{\frac{1}{2}}(2 \pi)^{\frac{p}{2}}\right)+\ln P(k) & (5) \\ &=\underset{k}{\operatorname{argmax}} -\frac{1}{2}\left(x-\mu_{k}\right)^{T} \Sigma_{k}^{-1}\left(x-\mu_{k}\right)-\frac{1}{2} \ln \left(\left|\Sigma_{k}\right|\right)+\ln P(k) & (6) \end{aligned} h(x)=kargmaxP(kx)=kargmaxlnP(kx)=kargmaxlnfk(x)+lnP(k)=kargmaxlnΣk21(2π)2pe2(xμk)TΣk1(xμk)+lnP(k)=kargmax21(xμk)TΣk1(xμk)ln(Σk21(2π)2p)+lnP(k)=kargmax21(xμk)TΣk1(xμk)21ln(Σk)+lnP(k)(1)(2)(3)(4)(5)(6)

  观察上面的(6)式,可知是关于 x 的二次函数,所以这也是该算法被称为二次判别分析算法的原因。

三、代码实现

使用 Python 实现二次判别分析(QDA):

def qda(X, y):
   """
   二次判别分析(QDA)
   args:
       X - 训练数据集
       y - 目标标签值
   return:
       y_classes - 标签类别
       priors - 每类先验概率
       means - 每类均值向量
       sigmags - 每类协方差矩阵
       dets - 每类协方差矩阵行列式
   """
   # 标签值
   y_classes = np.unique(y)
   # 每类先验概率
   priors = []
   # 每类均值向量
   means = []
   # 每类协方差矩阵
   sigmags = []
   # 每类协方差矩阵行列式
   dets = []
   for idx in range(len(y_classes)):
       c = X[y==y_classes[idx]][:]
       # 先验概率
       prior = c.shape[0] / X.shape[0]
       priors.append(prior)
       # 均值向量
       mu = np.mean(c, axis=0)
       means.append(mu)
       # 协方差矩阵
       sigma = c - mu
       sigma = sigma.T.dot(sigma) / c.shape[0]
       sigmags.append(np.linalg.pinv(sigma))
       # 协方差矩阵行列式
       dets.append(np.linalg.det(sigma))
   return y_classes, priors, means, sigmags, dets

def discriminant(X, y_classes, priors, means, sigmags, dets):
   """
   判别新样本点
   args:
       X - 数据集
       y_classes - 标签类别
       priors - 每类先验概率
       means - 每类均值向量
       sigmags - 每类协方差矩阵
       dets - 每类协方差矩阵行列式
   return:
       分类结果
   """
   ps = []
   for idx in range(len(y_classes)):
       x = X - means[idx]
       p = - 0.5 * (np.sum(np.multiply(x.dot(sigmags[idx]), x), axis=1) + np.log(dets[idx])) + priors[idx]
       ps.append(p)
   return y_classes.take(np.array(ps).T.argmax(1))

四、第三方库实现

scikit-learn2 实现线性判别分析:

from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

# 初始化二次判别分析器
qda = QuadraticDiscriminantAnalysis()
# 拟合数据
qda.fit(X, y)
# 预测数据
qda.predict(X)

  sklearn 的实现并没有像上面的实现一样直接去计算协方差矩阵的逆矩阵,而是通过奇异值分解(SVD)的方式避免直接求协方差矩阵的逆矩阵,计算复杂度会小很多,具体可参考 sklearn 文档3 中对协方差矩阵的估计算法。

五、示例演示

  下图展示了存在二种分类时的演示数据,其中红色表示标签值为 0 的样本、蓝色表示标签值为 1 的样本:
1.png

  下面两张图分别展示了线性判别分析和二次判别分析拟合数据的结果,其中浅红色表示拟合后根据权重系数计算出预测值为 0 的部分,浅蓝色表示拟合后根据权重系数计算出预测值为 1 的部分:
12.png
2.png

  可以很明显的看到两种判别分析的决策边界的不同,线性判别分析只能学习线性边界,而二次判别分析可以学习二次边界,因此具有更大的灵活性。

六、思维导图

3.jpeg

七、参考文献

  1. https://en.wikipedia.org/wiki/Quadratic_classifier#Quadratic_discriminant_analysis
  2. https://scikit-learn.org/stable/modules/generated/sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis.html
  3. https://scikit-learn.org/stable/modules/lda_qda.html#estimation-algorithms

完整演示请点击这里

注:本文力求准确并通俗易懂,但由于笔者也是初学者,水平有限,如文中存在错误或遗漏之处,恳请读者通过留言的方式批评指正

本文首发于——AI导图,欢迎关注

  • 8
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值