pytorch中矩阵乘法运算总结

1. torch.mm()

  • 功能: 用于两个2D密集矩阵的矩阵乘法(严格的矩阵乘法,不适用于向量或更高维度的张量。
  • 输入: 仅支持 2D 张量,即两个矩阵的维度必须为 [m, n][n, p],结果的维度为 [m, p]
  • 适用场景: 纯 2D 矩阵的乘法计算。

2. torch.sparse.mm()

  • 功能: 只支持稀疏矩阵(sparse matrix)和密集矩阵(dense matrix)之间的矩阵乘法,输入的第一个矩阵必须是稀疏矩阵。
  • 输入:
    • 第一个参数是稀疏矩阵(如 COO 格式的稀疏矩阵)
    • 第二个参数是密集矩阵/稀疏矩阵
  • 适用场景: 当处理大规模稀疏数据(如图神经网络的邻接矩阵)时,用于减少内存开销和加速计算。
a = torch.randn(2, 3).to_sparse()
b = torch.randn(3, 2)
y = torch.sparse.mm(a, b)
'''
a:
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 0.2511, -1.2641,  0.8681, -2.5626, -0.3517, -0.1242]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo)
b:
tensor([[ 1.2155, -2.2681],
        [-0.0799, -1.3815],
        [-1.7399, -1.3790]])
y:
tensor([[-1.1041, -0.0203],
        [-2.8706,  6.4692]])
'''

3. torch.mul

  • 功能: 用于张量的元素级乘法,即两个张量的相同位置上的元素相乘。按元素进行乘法运算,要求两个张量的形状相同或广播兼容。
  • 输入: 可以是任意形状的张量,只要它们的形状相同或能通过广播机制匹配。
  • 适用场景: 用于逐元素操作,例如标量与张量相乘,或两个形状相同的张量逐元素相乘。
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
result = torch.mul(a, b)  
# result = [[1*5, 2*6], [3*7, 4*8]] = [[5, 12], [21, 32]]

4. torch.matmul

  • 功能: 支持多种场景,包括:

    • 1D 向量与 1D 向量相乘(内积)
    • 2D 矩阵与 2D 矩阵相乘(类似 torch.mm
    • 2D 矩阵与 1D 向量相乘(矩阵和向量乘法)
    • 高维张量的矩阵乘法,支持广播
  • 输入: 1D、2D 或更高维度的张量,PyTorch 会根据张量的维度自动选择合适的乘法操作。

  • 适用场景: 用于各种矩阵乘法场景,比 torch.mm 更通用,可以处理高维度的张量。

# Vector dot product (1D)
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
result = torch.matmul(a, b)  # Dot product, result = 1*3 + 2*4 = 11

# 2D matrix multiplication (like torch.mm)
a = torch.randn(2, 3)
b = torch.randn(3, 4)
result = torch.matmul(a, b)  # Shape will be (2, 4)

# 3D tensor multiplication with broadcasting
a = torch.randn(10, 3, 4)
b = torch.randn(4)
result = torch.matmul(a, b)  # Shape will be (10, 3), broadcasting the 4D vector

5. torch.bmm
与torch.mm类似,但多了一个batch_size维度,矩阵张量1维度是(b×m×n),矩阵张量2维度是(b×n×p),输出维度为(b×m×p)

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
result = torch.matmul(a, b)  # Shape will be (10, 3, 5)

结合批量分析 torch.matmul中

# 3D tensor multiplication with broadcasting
a = torch.randn(10, 3, 4)
b = torch.randn(4)

3D的第一个维度当作是批量,即a有10个(3, 4)的矩阵,b是1D,可以视为行向量或者列向量均可,在进行矩阵乘法axb时,需要满足a的列数=b的行数,广播机制后可以进行运算,size为(10,3)

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值