- 该问题很可能是由浮点数精度引起
- 如果使用的numpy数据的dtype是np.float64,然后转换成pytorch的FloatTensor,就会导致该错误。
示例:
In [32]: img = torch.from_numpy(np.random.randn(1, 3, 10, 224, 224)) 此时数据类型是浮点数64位的
In [33]: conv2(img)
更改方案
img = torch.from_numpy(np.random.randn(1, 3, 10, 224, 224).astype(np.float32))