线性回归

136 篇文章 17 订阅
39 篇文章 16 订阅

线性回归问题

W X = Y WX=Y WX=Y已知 X , Y X,Y X,Y,欲求 W W W,其中 W ∈ R L × M , X ∈ R M × N , Y ∈ R L × N W\in R^{L\times M}, X\in R^{M\times N}, Y\in R^{L\times N} WRL×M,XRM×N,YRL×N


最小二乘法

W = a r g m i n ∣ ∣ W X − Y ∣ ∣ 2 2 W=argmin ||WX-Y||^2_2 W=argminWXY22解得 W = Y X T ( X X T ) − 1 W = YX^T(XX^T)^{-1} W=YXT(XXT)1

岭回归

W = a r g m i n ∣ ∣ W X − Y ∣ ∣ 2 2 + β ∣ ∣ W ∣ ∣ 2 2 W=argmin ||WX-Y||^2_2+\beta ||W||^2_2 W=argminWXY22+βW22解得 W = Y X T ( X X T + β I ) − 1 W = YX^T(XX^T+\beta I)^{-1} W=YXT(XXT+βI)1

奇异值分解

X = U Λ V X = U\Lambda V X=UΛV其中, U ∈ R M × K , V ∈ R K × N , Λ ∈ R K × K , K = m i n { M , N } U\in R^{M\times K}, V\in R^{K\times N}, \Lambda \in R^{K\times K}, K=min\{M,N\} URM×K,VRK×N,ΛRK×K,K=min{M,N}。因此,
W = Y V T Λ − 1 U T W=YV^T\Lambda^{-1}U^T W=YVTΛ1UT为了避免 X X X 有特征值为 0 使得 Λ = d i a g { λ i } \Lambda=diag\{\lambda_i\} Λ=diag{λi} 不可逆: Λ − 1 = d i a g { λ i λ i 2 + ϵ } \Lambda^{-1}=diag\{\frac{\lambda_i}{\lambda_i^2+\epsilon}\} Λ1=diag{λi2+ϵλi}其中 ϵ \epsilon ϵ 为非常小的正数( 1 0 − 6 10^{-6} 106)。

实验

import numpy as np
import time


def MSE(X,Y):
    return  np.mean((X-Y)**2)


def solve_WX_Y(X,Y, method='SVD', reg=1e-6):
    if method == 'SVD':  # SVD, accurate method
        U, s, V = np.linalg.svd(X, full_matrices=False);
#         print('reconstruction error:',MSE(X,U.dot(np.diag(s)).dot(V)))
        s = s / (s ** 2 + reg)

        W = Y.dot(np.multiply(V.T, np.expand_dims(s, 0)).dot(U.T));

    else:  # NormalEquation, fast method
        B = Y.dot(X.T)
        A = X.dot(X.T)

        W = np.linalg.solve((A + np.eye(A.shape[0], A.shape[1]) * reg), B.T).T
        
    return W



L = 1000
M = 5000
N = 5000

X = np.random.random((M,N))
W = np.random.random((L,M))
Y = np.dot(W,X)+np.random.random((L,N))*0.01

lb = 0

print('---------SVD---------')
time_start=time.time()
W_solve = solve_WX_Y(X,Y,method='SVD',reg=lb)
time_end=time.time()
print('time:',time_end-time_start)
print('mse :',MSE(W,W_solve))


print('--------Rigde--------')
time_start=time.time()
W_solve = solve_WX_Y(X,Y,method='Ridge', reg=lb)
time_end=time.time()
print('time:',time_end-time_start)
print('mse :',MSE(W,W_solve))

'''
---------SVD---------
time: 0.5515260696411133
mse : 2.5135123133318373e-08
--------Rigde--------
time: 0.12566399574279785
mse : 2.513512313268334e-08
'''

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值