java 并行 矩阵乘法_使用tensordot实现批量矩阵乘法

这篇博客探讨了如何在不支持并行matmul的环境中,利用numpy的tensordot和reshape函数,来实现与np.matmul相同效果的并行矩阵乘法。作者通过示例展示了在小矩阵并行计算中,如何通过调整操作避免迭代第一维,以充分利用BLAS/GPU加速。目前的问题是tensordot的输出不匹配matmul的结果,博主寻求修正tensordot表达式的方法。
摘要由CSDN通过智能技术生成

我试图通过使用tensordot,dot和reshaping等来实现与np.matmul并行矩阵乘法相同的行为 .

我正在将其翻译为使用的库没有支持并行乘法的matmul,只有dot和tensordot .

另外,我想避免迭代第一维,并希望使用一组矩阵乘法和重新整形(希望尽可能多地使用BLAS / GPU运行,因为我有大量小矩阵并行计算) .

这是一个例子:

import numpy as np

angles = np.array([np.pi/4, 2*np.pi/4, 2*np.pi/4])

vectors = np.array([ [1,0],[1,-1],[-1,0]])

s = np.sin(angles)

c = np.cos(angles)

rotations = np.array([[c,s],[-s,c]]).T

print rotations

print vectors

print("Correct: %s" % np.matmul(rotations, vectors.reshape(3,2,1)))

# I want to do this using tensordot/reshaping, i.e just gemm BLAS operations underneath

print("Wrong: %s" % np.tensordot(rotations, vectors, axes=(1,1)))

这个输出是:

Correct: [[[ 7.07106781e-01]

[ 7.07106781e-01]]

[[ 1.00000000e+00]

[ 1.00000000e+00]]

[[ -6.12323400e-17]

[ -1.00000000e+00]]]

Wrong: [[[ 7.07106781e-01 1.11022302e-16 -7.07106781e-01]

[ -7.07106781e-01 -1.41421356e+00 7.07106781e-01]]

[[ 6.12323400e-17 -1.00000000e+00 -6.12323400e-17]

[ -1.00000000e+00 -1.00000000e+00 1.00000000e+00]]

[[ 6.12323400e-17 -1.00000000e+00 -6.12323400e-17]

[ -1.00000000e+00 -1.00000000e+00 1.00000000e+00]]]

有没有一种方法可以修改第二个表达式,以获得与第一个表达式相同的结果,只需使用dot / tensordot .

我相信它是可能的,并且已经看过some comments online,但从来没有任何例子

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值