报错代码:
Wh = torch.mm(h, self.w)
报错: RuntimeError: self must be a matrix
原因:torch.mm()是两个矩阵相乘,即两个二维的张量相乘,维度超过二维,则会报错。
这两个tensor的维度是[16, 16, 29]
和[29, 70]
>>> h.shape
torch.Size([16, 16, 29])
>>> self.w.shape
torch.Size([29, 70])
修改:使用torch.matmul()
Wh = torch.matmul(h, self.w)
>>>Wh.shape
torch.Size([16, 16, 70])