python Kaczmarz 算法实现以及绘图展示

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

X_list = []


def Kaczmarz(A:np.ndarray, b:np.ndarray, x:np.ndarray, u=1, eps=0.001):
    if not (len(A.shape) == 2 and len(b.shape) == 1):
        exit(-1)
    r, c = A.shape
    l = b.shape[0]
    if l != r or r > c:
        exit(-1)
    global X_list
    X_list.append(x)
    while np.linalg.norm(A@x -b) >= eps:
        for i in range(r):
            x = x + u * (b[i] - A[i]@x)*(A[i]/(A[i]@A[i]))
            X_list.append(x)


def depict(A, b, x_list):
    x_list = np.asarray(x_list)
    r, _ = A.shape
    fig, ax = plt.subplots()
    x = np.linspace(-2, 8, num=1000)
    for i in range(r):
        y = 1.0 * (b[i] - A[i, 0] * x) / A[i, 1]
        ax.plot(x, y)
    plt.scatter(x_list[:, 0], x_list[:, 1])
    for a in x_list:
        plt.scatter(a[0], a[1])

    for i in range(1, x_list.shape[0]):
        p1 = x_list[i-1]
        p2 = x_list[i]
        plt.plot([p1[0], p2[0]], [p1[1], p2[1]])
        con = ConnectionPatch(p1, (p2[0] - (p2[0] - p1[0]) / 5, p2[1] - (p2[1] - p1[1]) / 5), "data", "data",
                              arrowstyle="->", shrinkA=5, shrinkB=5,
                              mutation_scale=10, fc="w")
        ax.add_artist(con)
    plt.show()


if __name__ == '__main__':
    A = np.array([[1, -1],
                  [0, 1]])
    b = np.array([2, 3])
    x = np.array([0, 0])
    Kaczmarz(A, b, x)
    depict(A, b, X_list)

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值