torch中的乘法符号(*),torch.mm()和torch.matmul(),torch.mul(), torch.bmm()

前言

torch中常见的一些矩阵乘法和元素乘积,说白了无非就是以下四种,为了避免忘了,做个笔记

  1. 乘法符号 *
  2. torch.mul()
  3. torch.mm
  4. torch.matmul
  5. torch.dot

1. 对比

  1. 乘法符号*
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

node_0 = node.unsqueeze(-1)
'''
tensor([[[1],
         [1],
         [1],
         [0],
         [0]],
        [[1],
         [1],
         [1],
         [1],
         [1]]], torch.Size([2,1,5]))
'''
node_1 = node.unsqueeze(1)
'''
tensor([[[1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1]]], torch.Size([2,5,1]))
'''

'True, 值相同'
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1) == /
node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)

print(node_mask.unsqueeze(-1) * node_mask.unsqueeze(1))
'shape=[2,5,5]'

所以,乘法符号是对应的tensor和元素乘。

2.torch.mul()

和上面一样,不同是有官方解释

torch.mul(input, value, out=None)

用标量值value乘以输入input的每个元素,并返回一个新的结果张量。 out=tensor∗value

# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]
'shape=[2,5,5]'
torch.mul(node_0, node_1) #
'True,看来是相同的'
torch.mul(node_0, node_1) ==  node_mask.unsqueeze(-1) * node_mask.unsqueeze(1)
  1. torch.matmul()和torch.mm()和torch.bmm()
  • 在矩阵的情况下,矩阵就是shape=n×m格式,只能用torch.mm()torch.matmul()不可以使用torch.bmm()
  • 但是大多数情况都是带有batch_size,也就是sahpe=[batch_size, n, m],只能用torch.matmul()torch.bmm()
# shape=(2,5)
node = tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
node_0 = node.unsqueeze(-1) # [2,5,1]
node_1 = node.unsqueeze(1) # [2,1,5]

'先看mm和matmul'
torch.mm(node_mask, node_mask.transpose(0,1)) # [2,2]
torch.matmul(node_mask, node_mask.transpose(0,1)) # [2,2]
# 下面 True
torch.mm(node_mask, node_mask.transpose(0,1)) == /
torch.matmul(node_mask, node_mask.transpose(0,1))

# error:  torch.mm()就不可以在这种三维tensor下用
torch.mm(node_0, node_1) # 报错
'而torch.matmul()'
torch.matmul(node_0, node_1)


'看高维Tensor'
torch.matmul(node_0, node_1) # --> [2,1,1]
torch.bmm(node_0, node_1) # --> [2,1,1]
torch.bmm(node_mask.unsqueeze(1), node_mask.unsqueeze(-1)) # --> [2,1,1]

'True'
torch.matmul(node_0, node_1) == torch.bmm(node_0, node_1) 

注意:dot很快运算速度远超于matmul

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值