Error size mismatch, m1: [1 x 1903104], m2: [523776 x 128] at …\aten\src\TH/generic/THTensorMath.cpp:41)
问题原因:
网络的参数对应不上,对应的最后一层(fc1层),经过排查发现是Image输入网络时没有Resize成图片train时的大小(256*256)
解决方法:
在transform中加一条transform.resize(256)即可。
Ref: https://discuss.pytorch.org/t/size-mismatch-m1-16-x-4096-m2-1024-x-3-at-pytorch-aten-src-th-generic-thtensormath-cpp-41/88730
问题复现:
input = torch.randn(16, 4096)</