Python 下 Numba 加速实现比 Numpy.linalg.det() 更快速的行列式计算

本文展示了如何利用numba库优化计算任意大小矩阵行列式的过程,实现比numpy.linalg.det()快约十倍的效率。通过编写特定的计算代码并应用numba的jit加速,实现了与numpy和pytorch计算结果一致但速度显著提升的效果。
摘要由CSDN通过智能技术生成

简介

本文描述一种使用numba加速任意大小矩阵的行列式,实现比用numpy线性代数库中的函数numpy.linalg.det()快约十倍的计算速度。

背景

一项目中需要频繁计算形状为(Batch,K, K)大小矩阵的行列式,于是使用numpy线性代数库中的函数numpy.linalg.det(), 示例如下:

b = 256**2
k = 3
A = np.random.randn(b,k,k)
npDet = np.linalg.det(A)
print("The Numpy Determinant of A is", round(sum(npDet),9))

输出:
The Numpy Determinant of A is 176.93690024

简单使用timeit测试一下速度:

%timeit npDet = np.linalg.det(A)

输出:
10.8 ms ± 90 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

平均10ms的执行速度实在是有点浪费生命,人生苦短,不如自己造轮子。

再造一个更快的轮子:

接下来我们重写一个计算行列式的代码并使用numba.jit()进行加速:

@numba.jit(nopython=True)
def determinant(A):
    """A shape should be (B,n,n) return shape (B,)  """
    B,n,_ = A.shape
    AM = A.copy()
    # Section 2: Row ops on A to get in upper triangle form
    product = np.ones(B,dtype=A.dtype)
    for b in range(B):
        for fd in range(n): # A) fd stands for focus diagonal
            for i in range(fd+1,n): # B) only use rows below fd row
                if AM[b][fd][fd] == 0: # C) if diagonal is zero ...
                    AM[b][fd][fd] = 1.0e-18 # change to ~zero
                # D) cr stands for "current row"
                crScaler = AM[b][i][fd] / AM[b][fd][fd] 
                # E) cr - crScaler * fdRow, one element at a time
                for j in range(n): 
                    AM[b][i][j] = AM[b][i][j] - crScaler * AM[b][fd][j]
        # Section 3: Once AM is in upper triangle form ...
        for i in range(n):
            # ... product of diagonals is determinant
            product[b] *= AM[b][i][i] 
 
    return product

测试并与numpy和pytorch的函数比较一下:

b = 256**2
k = 3
A = np.random.randn(b,k,k)
Det = determinant(A)
npDet = np.linalg.det(A)
torch_Det = torch.linalg.det(torch.Tensor(A))
print("Determinant of A is", round(sum(Det),9))
print("The Numpy Determinant of A is", round(sum(npDet),9))
print("The Torch Determinant of A is", round(sum(npDet),9))
%timeit Det = determinant(A)
%timeit npDet = np.linalg.det(A)
%timeit torch_Det = torch.linalg.det(torch.Tensor(A))

结果如下:
Determinant of A is -397.275054658
The Numpy Determinant of A is -397.275054658
The Torch Determinant of A is -397.275054658
1.73 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
10.8 ms ± 90 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.73 ms ± 381 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
可以发现,计算结果与函数库相同,保证计算结果正确。同时实现比numpy快约十倍,比pytorch快约一倍

结语:

为了不浪费各位的生命,就不在这里放使用C实现的版本了,相信numba版本的实现已经可以满足大部分使用场景。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值