RuntimeError: einsum() operands do not broadcast with remapped shapes [original->remapped]: [4, 64, 7, 72, 72]->[4, 64, 1, 72, 72, 7] [3, 3, 64]->[1, 64, 3, 1, 1, 3]解决方法
问题描述:
out = torch.einsum('bcdhw,dkc->bckhw', [input, self.adaptive_align_weights])
在运行上行代码的时候报了标题的错误,表面上看起来好像是维度不匹配,所以我找了一晚上input和self.adaptive_align_weights的维度问题(T-T)但是都没有解决(当然解决不了,因为维度没有问题啊😡)
解决办法:
einsum()中定义的 ‘bcdhw,dkc->bckhw’ ,要保证箭头左边相同的字母维度是一致的。例如在我的定义式中,两个输入的维度d和c要保持一致。我的程序报错就是因为我的input是(4, 64, 7, 128, 128),self.adaptive_align_weights是(3, 3, 64),它们俩的d一个是7,一个是3,不匹配就报错了!