机器学习13-LDA线性判别分析-python

16 篇文章 0 订阅
15 篇文章 0 订阅

1. 原理推导

在这里插入图片描述在这里插入图片描述在这里插入图片描述

2. python简单实现

def center(data_mat, target):
    clf_list = set(target)
    data = []
    
    for clf in clf_list:
        data.append(data_mat[target == clf])
    
    center0 = []
    for d in data:
        center0.append(np.mean(d, axis=0))
    
    return data, center0
def computer_sw_mat(data, center0, n):
    ret_mat = np.mat(np.zeros((n, n)))
    for i in range(len(data)):
        xi_center = center0[i]
        xi_center = np.mat(xi_center).T
        clf_data = data[i]
        for d in clf_data:
            d = np.mat(d).T
            
            ret_mat += (d - xi_center) * (d - xi_center).T
    return ret_mat
def computer_sb_mat(center0, n):
    ret_mat = np.mat(np.zeros((n, n)))
    k = len(center0)
    for i in range(k):
        
        for j in range(i+1, k):
            ci = np.mat(center0[i]).T
            cj = np.mat(center0[j]).T
            ret_mat += (ci - cj) * (ci - cj).T
    return ret_mat
def lda(dataset, target, k):
    data, center0 = center(dataset, target)
    # print center0
    m, n = dataset.shape
    sw_mat = computer_sw_mat(data, center0, n)
    # print sw_mat
    sb_mat = computer_sb_mat(center0, n)
    # print sb_mat
    eig_val, eig_vec = np.linalg.eig(sw_mat.I * sb_mat)
    print eig_val, eig_vec
    
    decom_mat = eig_vec[:k]
    new_data = dataset * decom_mat.T
    return new_data

3. sklearn

if __name__ == '__main__':
    from sklearn.datasets import load_iris
    iris = load_iris()
    dataset = iris.data
    target = iris.target
    new_data = lda(dataset, target, 2)
    # print new_data
    
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA1
    de = LDA1(solver='eigen', n_components=2)
    de.fit(dataset, target)
    newd = de.transform(dataset)
    print de.explained_variance_ratio_
    # print newd
    import matplotlib.pyplot as plt
    plt.subplot(311)
    plt.scatter(dataset[:,2].tolist(),dataset[:,1].tolist(), marker='p')
    plt.subplot(312)
    plt.scatter(new_data[:, 0].tolist(), new_data[:, 1].tolist(), marker='o')
    plt.subplot(313)
    plt.scatter(newd[:, 0].tolist(), newd[:, 1].tolist(), marker='x')
    plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值