用pytorch搭建网络测试时,代码报错如下:
Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 ‘weight’
搭建的网络为AlexNet,测试代码如下
from torchvision import models
model = models.alexnet(pretrained=True)
x = np.random.rand(1,3,224,224)
#x = x.astype(np.float32)
x_ts = torch.from_numpy(x)
x_in = Variable(x_ts)
y = model (x_in)
报类型错误,默认x类型为float64,加上注释那句运行正确。