文章目录
1. 张量的乘法(矩阵的乘法)
API | 说明 |
---|---|
matmul() | 矩阵的乘法 |
@ | 重载运算符 |
多维张量的乘法是如何定义的呢?不考虑前几维,只考虑数据的最后两维。如维度[4, 3, 64, 42] 的张量 @ 维度[4, 3, 42, 128]的张量,结果则为[4, 3, 64, 128]的张量。有时也会用到广播特性,暂不介绍。 |
q1 = torch.rand(4, 3, 42, 64)
q2 = torch.rand(4, 3, 64, 128)
ret = q1@q2
print(ret.shape)
结果为:
torch.Size([4, 3, 42, 128])
API | 说明 |
---|---|
bmm() | batch的矩阵乘法 |
torch.bmm