2022.9.26西瓜书3.5线性判别分析编程

题目: 编程实现线性判别分析,并给出西瓜数据集3.0α上的结果。

-------------------------------------------------------------------------------------------------------------------------------- 

代码转自:西瓜书——第三章课后习题,本文就代码进行了解读和注释。

 ctrl+/快捷键可以对代码注释

import numpy as np
import math
import matplotlib.pyplot as plt

data_x = [[0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318], [0.556, 0.215], [0.403, 0.237],
          [0.481, 0.149], [0.437, 0.211],
          [0.666, 0.091], [0.243, 0.267], [0.245, 0.057], [0.343, 0.099], [0.639, 0.161], [0.657, 0.198],
          [0.360, 0.370], [0.593, 0.042], [0.719, 0.103]]
data_y = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

#求出两个均值向量
mu_0 = np.mat([0., 0.]).T  
mu_1 = np.mat([0., 0.]).T  
count_0 = 0
count_1 = 0
for i in range(len(data_x)):
    x = np.mat(data_x[i]).T
    if data_y[i] == 1:
        mu_1 = mu_1 + x
        count_1 = count_1 + 1
    else:
        mu_0 = mu_0 + x
        count_0 = count_0 + 1
mu_0 = mu_0 / count_0
mu_1 = mu_1 / count_1

#类内散度矩阵
S_w = np.mat([[0, 0], [0, 0]])
for i in range(len(data_x)):
    # 注意:西瓜书的输入向量是列向量形式
    x = np.mat(data_x[i]).T
    if data_y[i] == 0:
        S_w = S_w + (x - mu_0) * (x - mu_0).T
    else:
        S_w = S_w + (x - mu_1) * (x - mu_1).T

 本题的精华如下,依据课本公式(3.39)可以求出w利用SVE(奇异值分解)求解出w。

 np.linalg.svd()函数用于计算奇异值矩阵。

如何理解奇异值矩阵可以参考:keyicka【学长小课堂】什么是奇异值分解SVD--SVD如何分解时空矩阵

#求解出w,w是二维列向量
u, sigmav, vt = np.linalg.svd(S_w)    #对类内散度矩阵进行奇异值分解,sigmav为奇异值矩阵
sigma = np.zeros([len(sigmav), len(sigmav)])    #创建一个2*2的0矩阵
for i in range(len(sigmav)):  
    sigma[i][i] = sigmav[i]   #还原出奇异值矩阵
sigma = np.mat(sigma)   
S_w_inv = vt.T * sigma.I * u.T
w = S_w_inv * (mu_0 - mu_1)

在求得w后,下面的代码将①好瓜坏瓜的点分别绘制加以不同的形状颜色,②绘制坐标轴,均值向量,③绘制wTx+b,④将各点的投影画出来。

#求w的三角关系
w_0 = w[0, 0]
w_1 = w[1, 0]
tan = w_1 / w_0
sin = w_1 / math.sqrt(w_0 ** 2 + w_1 ** 2)
cos = w_0 / math.sqrt(w_0 ** 2 + w_1 ** 2)

print(w_0, w_1)

#将两类点画出来,好瓜是三角形,坏瓜是圆形
for i in range(len(data_x)):
    if data_y[i] == 0:
        plt.plot(data_x[i][0], data_x[i][1], "go")
    else:
        plt.plot(data_x[i][0], data_x[i][1], "b^")

#绘制出两个类的均值向量和找出来的直线
plt.xlabel('x')
plt.ylabel('y')
plt.title('Linear Discriminant Analysis')
plt.plot(mu_0[0, 0], mu_0[1, 0], "ro")
plt.plot(mu_1[0, 0], mu_1[1, 0], "r^")
plt.plot([-0.1, 0.1], [-0.1 * tan, 0.1 * tan])  #横纵从-0.1到0.1,纵轴从下到上

#将点投影到找出来的直线上并绘制出来
for i in range(len(data_x)):
    x = np.mat(data_x[i]).T
    ell = w.T * x    #w和某个样本的的点乘得到该样本落在w上的长度
    ell = ell[0, 0]  #取出该值
    #绘制出落在w上的点
    if data_y[i] == 0:
        plt.scatter(cos * ell, sin * ell, marker='o', c='g', edgecolors='g')
    else:
        plt.scatter(cos * ell, sin * ell, marker='^', c='b', edgecolors='b')
plt.show()

运行结果如下: 

  • 1
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值