torch.matmul(a,b)处理的一般是a和b的最后两个维度,假设a的维度为B*F*M,b也为B*F*M, 在对a,b做相乘操作的时候,需要完成对B的维度顺序的变换,通过permute(0, 2, 1)变换为B*M*F。
通过变换后进行torch.matmul(a,b)得到结果为B*F*F,在除了最后两个维度的的之前维度上都被认为是Batch。
示例1:
>>> import torch
>>> a=torch.rand((1000,5,10))
>>> b=torch.rand((1000,10,12))
>>> c=torch.matmul(a,b)
>>> c.shape
torch.Size([1000, 5, 12])
在处理不同维度时,会通过广播来合并除最后两个维度外的其他维度,如对于A*B*F*M与B*M*F的matmul,结果为A*B*F*F
示例2:
>>> a=torch.rand((50,1000,5,10))
>>> b=torch.rand((1000,10,12))
>>> c=torch.matmul(a,b)
>>> c.shape
torch.Size([50, 1000, 5, 12])