这个错误信息表明在执行张量运算时,两个张量在非单一维度(这里是维度2)上的尺寸不匹配。具体来说,张量a
在维度2上的大小是2048,而张量b
在相同维度上的大小却是1000。
报错代码行为:
similarity = torch.sum(seq_norm * spec_norm, dim=2)
这是seq_norm和spec_norm做点积的时候要求其向量维度一致
seq_norm是a,seqc_norm是b,于是查看ab两个的大小,发现a是torch.Size([1, 4567, 2048]) b是torch.Size([1000])
b出了问题,找到b的出处,b是用resnet152提取的特征向量,维度为1000是因为直接保存了全连接层的结果(resnet的输出结果就是1000个类别)
要想匹配维度,需要修改b的保存,使b保存全连接层的前一层,就可以解决。
def forward(self, x):
x = self.transform(x)
x = x.unsqueeze(0).to(device)
x = Variable(x)
# 获取 'avgpool' 层的输出
res_features = self.resnet.layer4(x)
res_features = F.avg_pool2d(res_features, res_features.size()[3]) # 全局平均池化
res_features = res_features.squeeze(3).squeeze(2) # 去除不必要的维度
return res_features