RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x7 and 25088x4096)
使用VGG19提取图像特征时出现该问题
报错代码
output = self.features(x) # 输出维度为(512*7*7)
output = self.avgpool(output)
output = self.classifier(output)
原因分析
卷积层的输入为四维[batch_size,channels,H,W] ,而全连接层接受维度为2的输入,通常为[batch_size, size]
解决方案
在全连接层前加入维度变化
使用torch.flatten()
output = self.features(x) # 输出维度为(512*7*7)
output = self.avgpool(output)
output = torch.flatten(output, 1)
output = self.classifier(output)
还看到一种解决方案
x.view(-1,7* 7* 1024)