问题是你有2D数组但索引并分配它们就像它们是一维数组一样。因此,在将它们传递给numba函数之前,你可以只使用ravel()。我不确定这是否真的正确 - 但为了这个答案的目的,我认为它是。
此外,你不需要复制a和c,因为你不修改它们,你实际上只需要复制b和d的第一个元素。
所以工作函数可能如下所示:
import numba as nb
import numpy as np
@nb.njit
def TDMA(a,b,c,d):
n = len(d)
x = np.zeros(n)
bc = np.zeros(len(b))
bc[0] = b[0]
dc = np.zeros(len(d))
dc[0] = d[0]
for i in range(1, n):
w = a[i - 1] / bc[i - 1]
bc[i] = b[i] - w * c[i - 1]
dc[i] = d[i] - w * dc[i - 1]
x[n - 1] = dc[n - 1] / bc[n - 1]
for k in range(n - 2, -1, -1):
x[k] = (dc[k] - c[k] * x[k + 1]) / bc[k]
return x
你这样称呼它:
TDMA(a.ravel(), b.ravel(), c.ravel(), B.ravel())
因为我使用了ravel(),结果与np.linalg.solve的形状不同:
by default solver, x1 = [[ 3.05427975]
[-0.13569937]
[-0.18789144]
[ 4.03757829]]
by TDMA, x2 = [ 3.05427975 -0.13569937 -0.18789144 4.03757829]
但是我真的不会重新实现NumPy函数,除非你可以利用NumPy函数不知道的数据中的某些结构。 NumPy是一个高性能的库,它已经使用了非常精细的实现,所以偶然的重新实现可能只会对极小的数据集更快,或者你可以利用一些关于你的数据的事实(这允许一个非常高性能的算法) )。
我不得不承认我不知道“三对角矩阵算法”,但我知道一些BLAS libraries(通常是令人难以置信的快速数学库)实现它。 NumPy使用BLAS。
然而,SciPy为特殊矩阵类型提供了一些(非常快速的)特殊线性代数求解器:
inv(a[, overwrite_a, check_finite])计算矩阵的逆矩阵。
solve(a, b[, sym_pos, lower, overwrite_a, …])求解方程a矩阵的未知x的线性方程组a * x = b。
solve_banded(l_and_u, ab, b[, overwrite_ab, …])假设a是带状矩阵,求解方程a x = b表示x。
solveh_banded(ab, b[, overwrite_ab, …])求解方程a x = b。
solve_circulant(c, b[, singular, tol, …])求x对于C x = b,其中C是循环矩阵。
solve_triangular(a, b[, trans, lower, …])假设a是三角矩阵,求解方程a x = b表示x。
solve_toeplitz(c_or_cr, b[, check_finite])使用Levinson Recursion解决Toeplitz系统问题
[...]
关于你与map的问题:现任官方list of supported built-in functions不包括map。所以你不能在Numbas nopython模式中使用map。