矩阵维度
首先需要确认多维矩阵每个维度的对应含义。
a = torch.tensor([[[3.], [1.]], [[4.], [0.]], [[2.], [1.]]])
这是一个三维矩阵,他的size是:3,2,1
很容易理解和记忆,最外层[]内的元素数即size的第一个参数,第二层对应第二个参数,以此类推。
那么其中元素对应关系就有了:
a
000
=
3
a_{000}=3
a000=3
a
210
=
1
a_{210}=1
a210=1
整个矩阵就有:
[
[
3
1
]
[
4
0
]
[
2
1
]
]
\begin{bmatrix} \begin{bmatrix}3 \\ 1\end{bmatrix} \begin{bmatrix}4 \\ 0\end{bmatrix} \begin{bmatrix}2 \\ 1\end{bmatrix} \end{bmatrix}
[[31][40][21]]
正常来讲,最内侧的两维可以用矩阵行列表示,外面就只能通过嵌套矩阵表示了。
最内侧为行,次内侧为列。
矩阵乘法
pytorch的矩阵乘法中,主要内容为2维×2维。也就是用torch.mm这个函数。
torch.mm
直接上例子:
import torch
b = torch.tensor([[2., 3.]])
c = torch.tensor([[3.], [1.]])
print(b)
print(c)
out1 = torch.mm(b, c)
print(out1)
# ### output
tensor([[2., 3.]])
tensor([[3.],
[1.]])
tensor([[9.]])
[
2
3
]
×
[
3
1
]
=
[
9
]
\begin{bmatrix}2 \ 3 \end{bmatrix} ×\begin{bmatrix}3 \\ 1\end{bmatrix}=[9]
[2 3]×[31]=[9]
先理清矩阵维度,这个就很容易理解了。
torch.bmm
矩阵批处理乘法。怎么命名不重要,功能就是两批2维矩阵对应相乘。
a = torch.tensor([[[2., 3.], [1., 2.]], [[3., 4.], [0., 5.]]])
w = torch.tensor([[[3.], [1.]], [[2.], [4.]]])
print(a)
print(w)
out = torch.bmm(a, w)
print(out)
tensor([[[2., 3.],
[1., 2.]],
[[3., 4.],
[0., 5.]]])
tensor([[[3.],
[1.]],
[[2.],
[4.]]])
tensor([[[ 9.],
[ 5.]],
[[22.],
[20.]]])
torch.matmul
这个我只知道在二维时与mm结果一样,高维的更复杂,但是没有深入研究。