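Principle
The weather over the New Year holiday was really nice, and now, well fed and rested, it's back to work! On to business: in the previous post on LU & LDV decomposition, we mentioned that LU decomposition can simplify solving linear systems of the form AX=b. The general idea, as outlined in that post, is: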
\begin{aligned} \text{raw:}\ &AX=b \\ \text{do:}\ &A=LU \\ \text{get:}\ &LUX=b \\ \text{introduce } Y\text{:}\ &LY=b \\ \text{solve:}\ &Y \\ \text{back:}\ &UX=Y \\ \text{solve:}\ &X \end{aligned}
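Written out componentwise, the two triangular solves in the steps above take the following form. This is a sketch assuming L is an n×n lower triangular matrix and U is an n×n upper triangular matrix, both with nonzero diagonal entries; the symbols l_{ij}, u_{ij}, and n are notation introduced here for illustration.

```latex
\begin{aligned}
y_i &= \frac{1}{l_{ii}}\Bigl(b_i - \sum_{j=1}^{i-1} l_{ij}\,y_j\Bigr), & i &= 1,2,\dots,n \\
x_i &= \frac{1}{u_{ii}}\Bigl(y_i - \sum_{j=i+1}^{n} u_{ij}\,x_j\Bigr), & i &= n,n-1,\dots,1
\end{aligned}
```

Each solve needs on the order of n² arithmetic operations, so once A=LU is known, a new right-hand side b costs only two cheap triangular solves.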
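Textbook example
Let's walk through an example from the textbook. It makes the whole process easier to follow and helps clarify the structure of the code. If you are already comfortable with the computation, feel free to skip ahead~
Let AX=b, where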
A=\begin{bmatrix} 1 & -3 & 7 \\ 2 & 4 & -3 \\ -3 & 7 & 2 \end{bmatrix}, \quad b=\begin{bmatrix} 2 \\ -1 \\ 3 \end{bmatrix}
Perform LU decomposition on A:
A=LU=\begin{bmatrix} 1 & 0 & 0 \\ 2 & 1 & 0 \\ -3 & -1/5 & 1 \end{bmatrix} \begin{bmatrix} 1 & -3 & 7 \\ 0 & 10 & -17 \\ 0 & 0 & 98/5 \end{bmatrix}
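The factors above can be reproduced with a few lines of NumPy. The following is only a minimal sketch of Doolittle elimination, assuming no row exchanges are needed (true for this A); `lu_no_pivot` is a hypothetical helper written for illustration and is not the `LDV_2D` routine from the previous post.

```python
import numpy as np

def lu_no_pivot(A):
    """Doolittle LU factorization without row exchanges (illustrative helper)."""
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    L = np.eye(n)      # unit lower triangular factor
    U = A.copy()       # reduced in place to upper triangular form
    for k in range(n - 1):
        if U[k, k] == 0:
            raise ValueError("zero pivot: a row exchange would be needed")
        for i in range(k + 1, n):
            L[i, k] = U[i, k] / U[k, k]        # elimination multiplier
            U[i, k:] -= L[i, k] * U[k, k:]     # subtract a multiple of the pivot row
    return L, U

A = np.array([[1, -3, 7], [2, 4, -3], [-3, 7, 2]])
L, U = lu_no_pivot(A)
print(L)
print(U)
print(np.allclose(L @ U, A))  # the product should reproduce A
```

The multipliers stored in L are exactly the factors used to eliminate the entries below each pivot, which is why L has ones on its diagonal.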
Then LY=b becomes:
\left\{\begin{matrix} y_1&=&2 \\ 2y_1+y_2&=&-1 \\ -3y_1-\cfrac{1}{5}y_2+y_3&=&3 \end{matrix}\right.
Solving from the top down gives Y:
Y=\begin{bmatrix} 2 \\ -5 \\ 8 \end{bmatrix}
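The top-down solve of LY=b is forward substitution. Here is a minimal sketch, assuming L is lower triangular with a nonzero diagonal and using 1-D arrays for brevity (the post's own code uses column vectors); `forward_substitution` is a name chosen here for illustration.

```python
import numpy as np

def forward_substitution(L, b):
    """Solve L y = b for lower triangular L, working from the first row down."""
    L = np.asarray(L, dtype=float)
    b = np.asarray(b, dtype=float)
    y = np.zeros_like(b)
    for i in range(L.shape[0]):
        # subtract the contribution of the already-known y[0..i-1]
        y[i] = (b[i] - L[i, :i] @ y[:i]) / L[i, i]
    return y

L = np.array([[1, 0, 0], [2, 1, 0], [-3, -1 / 5, 1]])
b = np.array([2, -1, 3])
y = forward_substitution(L, b)
print(y)  # should agree with Y = [2, -5, 8] up to rounding
```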
Back-substituting, UX=Y becomes:
\left\{\begin{matrix} x_1-3x_2+7x_3&=&2 \\ 10x_2-17x_3 & = &-5 \\ \cfrac{98}{5}x_3&=&8 \end{matrix}\right.
Back-substituting from the bottom up gives X:
X=\begin{bmatrix} -\cfrac{27}{98}, \cfrac{19}{98} , \cfrac{20}{49} \end{bmatrix}^T\approx\begin{bmatrix} -0.27551, 0.19388 , 0.40816 \end{bmatrix}^T
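The bottom-up solve of UX=Y is back substitution. The sketch below assumes U is upper triangular with a nonzero diagonal and again uses 1-D arrays; `back_substitution` is an illustrative name, not part of the post's code.

```python
import numpy as np

def back_substitution(U, y):
    """Solve U x = y for upper triangular U, working from the last row up."""
    U = np.asarray(U, dtype=float)
    y = np.asarray(y, dtype=float)
    n = U.shape[0]
    x = np.zeros_like(y)
    for i in range(n - 1, -1, -1):
        # subtract the contribution of the already-known x[i+1..n-1]
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

U = np.array([[1, -3, 7], [0, 10, -17], [0, 0, 98 / 5]])
y = np.array([2, -5, 8])
x = back_substitution(U, y)
print(x)  # should agree with [-27/98, 19/98, 20/49] up to rounding

# sanity check against the original system AX=b
A = np.array([[1, -3, 7], [2, 4, -3], [-3, 7, 2]])
print(np.allclose(A @ x, [2, -1, 3]))
```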
Code implementation
The relevant snippet from the previous post's code is shown below; at this point, cm is the U matrix.
# line 116
if mode == 'LU':
    if solve_equation:
        return operator_list, resort_m, L, cm
Following the approach of the example above, the code can be written as follows:
@classmethod
def Flip(cls, M):
    """Flip Matrix: reverse the row order and, for multi-column input, the column order too
    Args:
        M ([np.ndarray]): input matrix
    Returns:
        [np.ndarray]: [result]
    """
    assert len(M.shape) == 2
    if M.shape[1] > 1:
        return np.fliplr(np.flipud(M))
    else:
        return np.flipud(M)

@classmethod
def Solve_Ax_b_Equation(cls, A, b, mode='LU', test=False):
    """Function to Solve Ax_b_like Equation using LU decomposition by Junno
    Args:
        A ([np.ndarray]): [A]
        b ([np.ndarray]): [b]
        mode ([string]): LU for standard solver, LSM for Least-square-method for incompatible equations, Gaussian for the Gaussian-elimination-method
        test ([bool]): show checking information, default to False
    Returns:
        [np.ndarray]: answer [x]
    Last edited: 22-01-02
    Author: Junno
    """
    assert len(A.shape) == len(b.shape)
    assert A.shape[0] == b.shape[0]
    # work on a floating-point copy so the row swaps below do not modify the
    # caller's b and integer input is not truncated
    b = b.astype(np.result_type(b.dtype, np.float32))
    if mode == "LU":
        # Check whether A and the augmented matrix [A|b] have the same rank
        if cls.Get_row_rank(A) != cls.Get_row_rank(np.concatenate((A, b), axis=1)):
            raise ValueError(
                "This matrix equation is incompatible, maybe you can try the Least Square Method ^_^")
        else:
            # do LU factorization on A
            operator_list, trans_m, L, U = cls.LDV_2D(
                A, mode="LU", Test=False, solve_equation=True)
            # print(operator_list)
            if test:
                print("L matrix from LU:")
                print(L)
                print("U matrix from LU:")
                print(U)
            # apply to b the same row exchanges that were applied to A
            for i in range(len(operator_list)):
                op = operator_list[i]
                if op[0] == 't':
                    m, n = op[1]
                    b[[m, n], :] = b[[n, m], :]
            if test:
                print('resort b: \n', b)
            Y = np.zeros_like(b)
            X = np.zeros_like(b)
            # solve LY=b by forward substitution
            for i in range(L.shape[0]):
                temp = 0.
                for j in range(i):
                    temp += Y[j, 0]*L[i, j]
                Y[i, 0] = (b[i, 0]-temp)/L[i, i]
            if test:
                print("Calculate intermediate param Y")
                print(Y)
            # flip U and Y so that UX=Y becomes a lower triangular system
            flip_U = cls.Flip(U)
            flip_Y = cls.Flip(Y)
            if test:
                print("Flip U:")
                print(flip_U)
                print("Flip Y:")
                print(flip_Y)
            # check zero in diag(flip_U)
            if 0 in np.diag(flip_U):
                if test:
                    print(np.diag(flip_U))
                raise ValueError("Can't solve this equations")
            else:
                # forward substitution on the flipped system
                for i in range(flip_U.shape[0]):
                    temp = 0.
                    for j in range(i):
                        temp += X[j, 0]*flip_U[i, j]
                    X[i, 0] = (flip_Y[i, 0]-temp)/flip_U[i, i]
                # flip x back to the right ordering
                return cls.Flip(X)
    elif mode == "LSM":
        # not implemented yet (to be continued)
        return
    elif mode == "Gaussian":
        # not implemented yet (to be continued)
        return
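The flip step in the code may look surprising, so here is a small standalone sketch of why it works. Reversing both the rows and the columns of an upper triangular U yields a lower triangular matrix, and reversing Y the same way gives an equivalent system whose unknowns appear in reverse order, so the forward-substitution loop can be reused. The values of U and Y are taken from the worked example; the variable names `Z` and `n` are introduced here for illustration.

```python
import numpy as np

U = np.array([[1., -3., 7.], [0., 10., -17.], [0., 0., 98 / 5]])
Y = np.array([[2.], [-5.], [8.]])

flip_U = np.fliplr(np.flipud(U))   # reverse row order, then column order
flip_Y = np.flipud(Y)              # reverse the entries of the right-hand side

# the flipped matrix is lower triangular
print(np.allclose(flip_U, np.tril(flip_U)))

# forward substitution on the flipped system; Z holds X in reverse order
n = flip_U.shape[0]
Z = np.zeros_like(Y)
for i in range(n):
    Z[i, 0] = (flip_Y[i, 0] - flip_U[i, :i] @ Z[:i, 0]) / flip_U[i, i]

X = np.flipud(Z)                   # restore the original ordering of the unknowns
print(X)
print(np.allclose(U @ X, Y))
```

An alternative that avoids the flips altogether is to loop over the rows of U from last to first, at the cost of writing a second substitution loop.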
Example
>>> A=np.array([[1,-3,7],[2,4,-3],[-3,7,2]]).reshape((3,3)).astype(np.float32)
>>> A
array([[ 1., -3.,  7.],
       [ 2.,  4., -3.],
       [-3.,  7.,  2.]], dtype=float32)
>>> b=np.array([[2,-1,3]]).reshape((3,1)).astype(np.float32)
>>> b
array([[ 2.],
       [-1.],
       [ 3.]], dtype=float32)
>>> X=Matrix_Solutions.Solve_Ax_b_Equation(A,b,"LU")
>>> X
array([[-0.276],
       [ 0.194],
       [ 0.408]], dtype=float32)
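As an independent cross-check of this result, the same system can be handed to NumPy's built-in solver. This sketch uses `np.linalg.solve`, which relies on a LAPACK routine based on LU factorization with partial pivoting; it is a reference computation only and not part of the `Matrix_Solutions` class. The names `x_ref` and `residual` are chosen here for illustration.

```python
import numpy as np

A = np.array([[1, -3, 7], [2, 4, -3], [-3, 7, 2]], dtype=np.float64)
b = np.array([[2], [-1], [3]], dtype=np.float64)

x_ref = np.linalg.solve(A, b)             # reference solution
residual = np.linalg.norm(A @ x_ref - b)  # should be close to zero

print(x_ref)
print(residual)
```

Comparing a hand-written solver against a trusted library routine and checking the residual is a cheap way to catch mistakes such as a forgotten row exchange.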
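To be continued
Both of these posts rely on a function that checks the rank of a matrix, which plays an important role in deciding whether a system can be solved. Next time we'll look at how to implement it~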
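Until then, the role of that check can be illustrated with NumPy's built-in `np.linalg.matrix_rank` as a stand-in. This is only a sketch of the criterion used in the solver (the system is solvable exactly when A and the augmented matrix [A|b] have the same rank); `is_consistent` is a hypothetical helper, not the `Get_row_rank` function from this series.

```python
import numpy as np

def is_consistent(A, b):
    """Return True when rank(A) == rank([A|b]), i.e. AX=b has a solution."""
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float).reshape(A.shape[0], -1)  # force column shape
    augmented = np.concatenate((A, b), axis=1)
    return np.linalg.matrix_rank(A) == np.linalg.matrix_rank(augmented)

# second row is twice the first on both sides: consistent
print(is_consistent([[1, 2], [2, 4]], [3, 6]))  # True
# same left-hand side, contradictory right-hand side: inconsistent
print(is_consistent([[1, 2], [2, 4]], [3, 7]))  # False
```

Note that `np.linalg.matrix_rank` decides the rank from singular values with a tolerance, so it may behave differently from an elimination-based rank routine on nearly singular matrices.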