pytorch一共有5种乘法
- *乘,element-wise乘法,支持broadcast操作
- torch.mul(),和*乘完全一样
- torch.mm(),矩阵叉乘,即对应元素相乘相加,不支持broadcast操作
- torch.bmm(),三维矩阵乘法,一般用于mini-batch训练中
- torch.matmul(),叉乘,支持broadcast操作
先定义下面的tensor(本文不展示print结果):
import torch
tensorA_2x3 = torch.tensor(
[[1,2,3],
[3,2,1]]
)
tensorB_1x3 = torch.tensor(
[[1,2,3]]
)
tensorC_scalar = 5
tensorD_2x1 = torch.tensor(
[[2],
[