参照西瓜书的课后习题3.5的要求,参考了一些资料,简单地实现了一下LDA。
数据还是西瓜数据3.0a
代码和数据,都挂在了我的git上:https://github.com/qdbszsj/LDA
首先第一部分还是画一个散点图,这个跟上一个习题是一样的,此处不详细表述了。
然后是先用sklearn偷懒实现一下LDA,这里要注意下模型参数的选择,对于小数据一般选择lsqr,这里给出了官方的reference可以查一下。
之后是自己实现LDA的步骤详解:
u是均值向量,对于二分类问题,每个个体有两个属性值,因此均值向量求出来是2*2的矩阵,就是求一下平均数。然后根据西瓜书P61的公式3.33求出类内散度矩阵,这里注意一下是列向量乘以自己的转置,最后的类内散度矩阵(within-class scatter matrix)是2*2的,记为Sw。这里根据公式3.39,我们求出Sw的逆矩阵就可以了,然而这里我们不用np.linalg.inv()来求逆矩阵,而是要考虑到数值的稳定性,采用先用奇异值分解(SVD),再用分解出的矩阵得到一个类似原Sw的逆矩阵的东西,我不太明白这里为何要用SVD绕一圈,为何这样就增加了数值的稳定性?查阅了相关的资料,有说矩阵是相似的,特征值等比例缩放所以没关系,具体的说法也没搜到,通常都是一笔带过,不知道为何这里不能直接对Sw求逆矩阵,望dalao们指教一下。然后根据公式3.39我们就得到了w,也就是一条线的方向或者说是一个方向向量(x,y),把数据垂直映射到这条线上就行了。
<