©作者|善财童子
学校|西北工业大学
研究方向|机器学习/射频微波
在知乎看到一篇讲解线性判别分析(LDA,Linear Discriminant Analysis)的文章,感觉数学概念讲得不是很清楚,而且没有代码实现。所以童子在参考相关文章的基础上在这里做一个学习总结,与大家共勉,欢迎各位批评指正~~
注意:在不加说明的情况下,所有公式的向量均是列向量,这个也会反映到代码中。
本文的基本思路来自以下文章:
https://www.adeveloperdiary.com/data-science/machine-learning/linear-discriminant-analysis-from-theory-to-code/
基本概念和目标
线性判别分析是一种很重要的分类算法,同时也是一种降维方法(这个我还没想懂)。和 PCA 一样,LDA 也是通过投影的方式达到去除数据之间冗余的一种算法。
如下图所示的 2 类数据,为了正确的分类,我们希望这 2 类数据投影之后,同类的数据尽可能的集中(距离近,有重叠),不同类的数据尽可能的分开(距离远,无重叠),左图的投影不好,因为 2 类数据投影后有重叠,而右图投影之后可以很好地进行分类,因为投影之后的 2 类数据之间几乎没有重叠,只是类内重叠得很厉害,而这正是我们想要的结果。
正交投影
因为 LDA 用到了投影,所以这里有必要科普一下投影的知识。以二维平面为例,如图所示
我们要计算向量 在 上的投影 ,很显然 与 成比例关系: ,其中 是一个常数。我们使用向量正交的概念来求出这个常数 。在上图中,向量 , 与 垂直,它们的内积为 0,即 ,即
注意:对于两个向量 x 和 y, ,所以有