import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
# Module-level trace of iterates written by Kaczmarz(): the starting point,
# then the iterate after every single-row update.  Kaczmarz() only appends to
# this list and never clears it, so repeated calls accumulate into it.
X_list = []


def Kaczmarz(A: np.ndarray, b: np.ndarray, x: np.ndarray, u=1, eps=0.001,
             *, max_iter=None):
    """Solve ``A @ x = b`` iteratively with cyclic row projections.

    Every step moves the iterate toward the hyperplane ``A[i] @ x = b[i]`` of
    a single row; with ``u=1`` the step lands exactly on that hyperplane.
    Rows are visited in order, and full sweeps over all rows are repeated
    until the residual norm ``||A @ x - b||`` is below ``eps``.  The starting
    point and every intermediate iterate are appended to the module-level
    ``X_list``.

    Args:
        A: 2-D coefficient matrix with no more rows than columns and without
            an all-zero row.
        b: 1-D right-hand side holding one entry per row of ``A``.
        x: Starting point.  The array itself is not modified; each update
            creates a new array.
        u: Relaxation factor that scales every projection step.
        eps: Residual-norm threshold below which iteration stops.
        max_iter: Optional cap on the number of full sweeps.  ``None`` (the
            default) means no cap, in which case the call does not return
            while the residual norm stays at or above ``eps``.

    Returns:
        The last iterate.  Its residual norm is below ``eps`` unless the
        ``max_iter`` cap stopped the iteration first.

    Raises:
        ValueError: If ``A`` is not 2-D, ``b`` is not 1-D, their row counts
            differ, ``A`` has more rows than columns, or ``A`` contains an
            all-zero row.
    """
    # Raise instead of calling exit(): a library routine must not terminate
    # the interpreter, and the caller deserves to know what was wrong.
    if A.ndim != 2 or b.ndim != 1:
        raise ValueError("A must be 2-D and b must be 1-D")
    n_rows, n_cols = A.shape
    if b.shape[0] != n_rows:
        raise ValueError("b must have one entry per row of A")
    if n_rows > n_cols:
        raise ValueError("A must not have more rows than columns")

    # The squared row norms are loop-invariant.  Computing them up front also
    # lets a zero row be rejected before the division below can turn the
    # iterate into NaN (which would silently end the loop, since a NaN
    # residual never compares >= eps).
    row_norms_sq = [A[i] @ A[i] for i in range(n_rows)]
    if any(norm_sq == 0 for norm_sq in row_norms_sq):
        raise ValueError("A must not contain an all-zero row")

    X_list.append(x)
    sweeps = 0
    while np.linalg.norm(A @ x - b) >= eps:
        if max_iter is not None and sweeps >= max_iter:
            break
        for i in range(n_rows):
            # Relaxed projection of x onto the hyperplane A[i] @ x = b[i].
            x = x + u * (b[i] - A[i] @ x) * (A[i] / row_norms_sq[i])
            X_list.append(x)
        sweeps += 1
    return x
def depict(A, b, x_list):
    """Plot the constraint lines of a two-unknown system and an iterate path.

    Each row ``i`` of ``A`` is drawn as the line
    ``A[i, 0] * x + A[i, 1] * y = b[i]``.  The points in ``x_list`` are drawn
    as markers, and consecutive points are joined by a segment carrying an
    arrow that points toward the later one.  The figure is then displayed
    with ``plt.show()``.

    Args:
        A: 2-D coefficient matrix; only columns 0 and 1 are used.
        b: Right-hand side, indexed once per row of ``A``.
        x_list: Sequence of points (for example the iterates recorded by a
            solver); only coordinates 0 and 1 of each point are used.
    """
    x_list = np.asarray(x_list)
    n_rows, _ = A.shape
    _, ax = plt.subplots()

    # Non-vertical lines are sampled over this fixed horizontal range.
    xs = np.linspace(-2, 8, num=1000)
    for i in range(n_rows):
        a0, a1 = A[i, 0], A[i, 1]
        if a1 != 0:
            ax.plot(xs, 1.0 * (b[i] - a0 * xs) / a1)
        elif a0 != 0:
            # Vertical line x = b[i] / a0.  Solving for y would divide by
            # zero, so draw it with x in data coordinates and y spanning the
            # whole axes height (blended x-axis transform).
            x_at = b[i] / a0
            ax.plot([x_at, x_at], [0, 1], transform=ax.get_xaxis_transform())
        # A row whose two coefficients are both zero describes no line.

    # All points in one call, then once more individually; each separate
    # scatter call picks its own colour from the cycle.
    ax.scatter(x_list[:, 0], x_list[:, 1])
    for point in x_list:
        ax.scatter(point[0], point[1])

    for start, end in zip(x_list[:-1], x_list[1:]):
        ax.plot([start[0], end[0]], [start[1], end[1]])
        # The arrow stops one fifth of the step short of the target point,
        # presumably so the head is not hidden under that point's marker.
        tip = (end[0] - (end[0] - start[0]) / 5,
               end[1] - (end[1] - start[1]) / 5)
        arrow = ConnectionPatch(start, tip, "data", "data",
                                arrowstyle="->", shrinkA=5, shrinkB=5,
                                mutation_scale=10, fc="w")
        ax.add_artist(arrow)
    plt.show()
if __name__ == '__main__':
    # Demo system in two unknowns:  x - y = 2  and  y = 3.
    A = np.array([[1, -1], [0, 1]])
    b = np.array([2, 3])
    # Start the iteration at the origin (integer zeros, same shape as b).
    x = np.zeros_like(b)

    # The solver records every iterate in the module-level X_list, which is
    # then drawn together with the two constraint lines.
    Kaczmarz(A, b, x)
    depict(A, b, X_list)