The linear regression problem
Consider $WX=Y$. Given $X$ and $Y$, we want to solve for $W$, where $W\in R^{L\times M}$, $X\in R^{M\times N}$, $Y\in R^{L\times N}$.
Least squares
$$W=\arg\min_W ||WX-Y||^2_2$$
When $XX^T$ is invertible, the solution is
$$W = YX^T(XX^T)^{-1}$$
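The closed form follows from setting the gradient of the objective to zero. The short derivation below is a sketch under two assumptions: the norm is read as the Frobenius norm (sum of squared entries), and $XX^T$ is invertible, which requires $X$ to have full row rank and hence $N \ge M$.

```latex
\begin{aligned}
J(W) &= \|WX - Y\|_2^2 = \operatorname{tr}\big((WX - Y)(WX - Y)^T\big) \\
\frac{\partial J}{\partial W} &= 2\,(WX - Y)\,X^T = 0 \\
W X X^T &= Y X^T \\
W &= Y X^T (X X^T)^{-1}
\end{aligned}
```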
Ridge regression
$$W=\arg\min_W ||WX-Y||^2_2+\beta ||W||^2_2$$
The solution is
$$W = YX^T(XX^T+\beta I)^{-1}$$
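As a quick sanity check of the ridge formula, the sketch below computes $W$ with a linear solve instead of an explicit inverse and verifies that the gradient of the objective vanishes at that point. The function name `ridge_closed_form`, the small sizes, and the random seed are illustrative assumptions, not part of the original experiment.

```python
import numpy as np

def ridge_closed_form(X, Y, beta):
    """Return W = Y X^T (X X^T + beta I)^{-1} without forming an explicit inverse."""
    A = X @ X.T + beta * np.eye(X.shape[0])   # (M, M), symmetric
    B = Y @ X.T                               # (L, M)
    # We need W A = B; because A is symmetric this is A W^T = B^T.
    return np.linalg.solve(A, B.T).T

rng = np.random.default_rng(0)
L, M, N = 3, 5, 20                # small illustrative sizes
X = rng.standard_normal((M, N))
Y = rng.standard_normal((L, N))
beta = 0.1

W = ridge_closed_form(X, Y, beta)

# Gradient of the objective, up to a factor of 2: (W X - Y) X^T + beta W
grad = (W @ X - Y) @ X.T + beta * W
print('max |gradient|:', np.abs(grad).max())   # should be near machine precision
```

Using `np.linalg.solve` rather than `np.linalg.inv` is generally cheaper and numerically better behaved. For $\beta > 0$ the matrix $XX^T+\beta I$ is always invertible, even when $XX^T$ is not.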
Singular value decomposition
$$X = U\Lambda V$$
where $U\in R^{M\times K}$, $V\in R^{K\times N}$, $\Lambda \in R^{K\times K}$, $K=\min\{M,N\}$. Substituting this into the least-squares solution gives
$$W=YV^T\Lambda^{-1}U^T$$
If $X$ has a singular value equal to 0, then $\Lambda=diag\{\lambda_i\}$ is not invertible. To avoid this, replace the inverse with the regularized form
$$\Lambda^{-1}=diag\{\frac{\lambda_i}{\lambda_i^2+\epsilon}\}$$
where $\epsilon$ is a very small positive number (for example $10^{-6}$).
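The regularized inverse above is not just a numerical safeguard: with $\epsilon=\beta$ it gives the same $W$ as the ridge closed form. The sketch below checks this numerically on a small random problem; the sizes, seed, and variable names are illustrative assumptions. Note that `np.linalg.svd` returns the right factor already in the $K\times N$ layout that the text calls $V$.

```python
import numpy as np

rng = np.random.default_rng(1)
L, M, N = 3, 5, 20                     # small illustrative sizes
X = rng.standard_normal((M, N))
Y = rng.standard_normal((L, N))
eps = 1e-3

# Thin SVD: X = U @ diag(s) @ Vh, where Vh plays the role of V in the text.
U, s, Vh = np.linalg.svd(X, full_matrices=False)
d = s / (s**2 + eps)                   # regularized reciprocal singular values
W_svd = Y @ (Vh.T * d) @ U.T           # Y V^T diag(d) U^T

# Ridge closed form with beta = eps: W = Y X^T (X X^T + eps I)^{-1}
W_ridge = np.linalg.solve(X @ X.T + eps * np.eye(M), X @ Y.T).T

print('SVD and ridge agree:', np.allclose(W_svd, W_ridge))

# A zero singular value maps to 0 instead of causing a division by zero.
s0 = np.array([2.0, 0.0])
d0 = s0 / (s0**2 + 1e-6)
print('regularized reciprocals:', d0)
```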
Experiment
import numpy as np
import time
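def MSE(X,Y):
    # mean squared error between two arrays of the same shape
    return np.mean((X-Y)**2)
def solve_WX_Y(X,Y, method='SVD', reg=1e-6):
    if method == 'SVD': # SVD, accurate method
        # V returned by np.linalg.svd is already K x N, matching X = U diag(s) V
        U, s, V = np.linalg.svd(X, full_matrices=False)
        # print('reconstruction error:',MSE(X,U.dot(np.diag(s)).dot(V)))
        s = s / (s ** 2 + reg)  # regularized reciprocal of the singular values
        W = Y.dot(np.multiply(V.T, np.expand_dims(s, 0)).dot(U.T))
    else: # NormalEquation, fast method
        B = Y.dot(X.T)
        A = X.dot(X.T)
        # solve W (A + reg*I) = B through the transposed system (A is symmetric)
        W = np.linalg.solve((A + np.eye(A.shape[0], A.shape[1]) * reg), B.T).T
    return W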
L = 1000
M = 5000
N = 5000
X = np.random.random((M,N))
W = np.random.random((L,M))
Y = np.dot(W,X)+np.random.random((L,N))*0.01
lb = 0  # regularization strength passed as reg; 0 means no regularization
print('---------SVD---------')
time_start=time.time()
W_solve = solve_WX_Y(X,Y,method='SVD',reg=lb)
time_end=time.time()
print('time:',time_end-time_start)
print('mse :',MSE(W,W_solve))
print('--------Ridge--------')
time_start=time.time()
W_solve = solve_WX_Y(X,Y,method='Ridge', reg=lb)
time_end=time.time()
print('time:',time_end-time_start)
print('mse :',MSE(W,W_solve))
'''
---------SVD---------
time: 0.5515260696411133
mse : 2.5135123133318373e-08
--------Ridge--------
time: 0.12566399574279785
mse : 2.513512313268334e-08
'''