matmul torch 详解_pytorch 乘法运算汇总与解析

本文详细介绍了 PyTorch 中的各种乘法运算,包括元素一一相乘(哈达玛积)、向量点乘、矩阵乘法、向量与矩阵相乘、矩阵与向量相乘,以及批量矩阵运算,如 batched matrix multiplication,并对比了 matmul 和 bmm 函数的使用场景和区别。
摘要由CSDN通过智能技术生成

pytorch 有多种乘法运算,在这里做一次全面的总结。

元素一一相乘

该操作又称作 "哈达玛积", 简单来说就是 tensor 元素逐个相乘。这个操作,是通过

也就是常规的乘号操作符定义的操作结果。torch.mul 是等价的。

import torch

def element_by_element():

x = torch.tensor([1, 2, 3])

y = torch.tensor([4, 5, 6])

return x * y, torch.mul(x, y)

element_by_element()

(tensor([ 4, 10, 18]), tensor([ 4, 10, 18]))

这个操作是可以 broad cast 的。

def element_by_element_broadcast():

x = torch.tensor([1, 2, 3])

y = 2

return x * y

element_by_element_broadcast()

tensor([2, 4, 6])

向量点乘

torch.matmul: If both tensors are 1-dimensional, the dot product (scalar) is returned.

如果都是1维的,返回的就是 dot product 结果

def vec_dot_product():

x = torch.tensor([1, 2, 3])

y = torch.tensor([4, 5, 6])

return torch.matmul(x, y)

vec_dot_product()

tensor(32)

矩阵乘法

torch.matmul: If both arguments are 2-dimensional, the matrix-matrix product is returned.

如果都是2维,那么就是矩阵乘法的结果返回。与 torch.mm 是等价的,torch.mm 仅仅能处理的是矩阵乘法。

def matrix_multiple():

x = torch.tensor([

[1, 2, 3],

[4, 5, 6]

])

y = torch.tensor([

[7, 8],

[9, 10],

[11, 12]

])

return torch.matmul(x, y), torch.mm(x, y)

matrix_multiple()

(tensor([[ 58, 64],

[139, 154]]), tensor([[ 58, 64],

[139, 154]]))

vector 与 matrix 相乘

torch.matmul: If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

如果第一个是 vector, 第二个是 matrix, 会在 vector 中增加一个维度。也就是 vector 变成了

与 matrix

相乘之后,变成

, 在结果中将

维 再去掉。

def vec_matrix():

x = torch.tensor([1, 2, 3])

y = torch.tensor([

[7, 8],

[9, 10],

[11, 12]

])

return torch.matmul(x, y)

vec_matrix()

tensor([58, 64])

matrix 与 vector 相乘

同样的道理, vector会被扩充一个维度。

def matrix_vec():

x = torch.tensor([

[1, 2, 3],

[4, 5, 6]

])

y = torch.tensor([

7, 8, 9

])

return torch.matmul(x, y)

matrix_vec()

tensor([ 50, 122])

带有batch_size 的 broad cast乘法

def batched_matrix_broadcasted_vector():

x = torch.tensor([

[

[1, 2], [3, 4]

],

[

[5, 6], [7, 8]

]

])

print(f"x shape: {x.size()} \n {x}")

y = torch.tensor([1, 3])

return torch.matmul(x, y)

batched_matrix_broadcasted_vector()

x shape: torch.Size([2, 2, 2])

tensor([[[1, 2],

[3, 4]],

[[5, 6],

[7, 8]]])

tensor([[ 7, 15],

[23, 31]])

batched matrix x batched matrix

def batched_matrix_batched_matrix():

x = torch.tensor([

[

[1, 2, 1], [3, 4, 4]

],

[

[5, 6, 2], [7, 8, 0]

]

])

y = torch.tensor([

[

[1, 2],

[3, 4],

[5, 6]

],

[

[7, 8],

[9, 10],

[1, 2]

]

])

print(f"x shape: {x.size()} \n y shape: {y.size()}")

return torch.matmul(x, y)

xy = batched_matrix_batched_matrix()

print(f"xy shape: {xy.size()} \n {xy}")

x shape: torch.Size([2, 2, 3])

y shape: torch.Size([2, 3, 2])

xy shape: torch.Size([2, 2, 2])

tensor([[[ 12, 16],

[ 35, 46]],

[[ 91, 104],

[121, 136]]])

上面的效果与 torch.bmm 是一样的。matmul 比 bmm 功能更加强大,但是 bmm 的语义非常明确, bmm 处理的只能是 3维的。

def batched_matrix_batched_matrix_bmm():

x = torch.tensor([

[

[1, 2, 1], [3, 4, 4]

],

[

[5, 6, 2], [7, 8, 0]

]

])

y = torch.tensor([

[

[1, 2],

[3, 4],

[5, 6]

],

[

[7, 8],

[9, 10],

[1, 2]

]

])

print(f"x shape: {x.size()} \n y shape: {y.size()}")

return torch.bmm(x, y)

xy = batched_matrix_batched_matrix()

print(f"xy shape: {xy.size()} \n {xy}")

x shape: torch.Size([2, 2, 3])

y shape: torch.Size([2, 3, 2])

xy shape: torch.Size([2, 2, 2])

tensor([[[ 12, 16],

[ 35, 46]],

[[ 91, 104],

[121, 136]]])

tensordot

这个函数还没有特别清楚。

def tesnordot():

x = torch.tensor([

[1, 2, 1],

[3, 4, 4]])

y = torch.tensor([

[7, 8],

[9, 10],

[1, 2]])

print(f"x shape: {x.size()}, y shape: {y.size()}")

return torch.tensordot(x, y, dims=([0], [1]))

tesnordot()

x shape: torch.Size([2, 3]), y shape: torch.Size([3, 2])

tensor([[31, 39, 7],

[46, 58, 10],

[39, 49, 9]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值