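When running PyTorch code, the following error occurred: RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional input of size [3, 224, 224] instead
Solution: convolution layers such as nn.Conv2d expect a batched 4-D input of shape [N, C, H, W] (batch, channels, height, width), but a single image converted to a tensor is 3-D, [C, H, W]. Insert a batch dimension of size 1 before passing the tensor to the network: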
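import torch
from PIL import Image
from torchvision import transforms
TARGET_IMG_SIZE = 224
img_to_tensor = transforms.ToTensor()  # PIL image -> float tensor of shape [C, H, W], values in [0, 1]
img = Image.open(imgpath).convert('RGB')  # read the image (imgpath is the path to your image file); force 3 channels
img = img.resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE))
tensor = img_to_tensor(img)  # convert the image to a tensor
print(tensor.shape)  # torch.Size([3, 224, 224])
tensor = torch.unsqueeze(tensor, dim=0).float()  # add the batch dimension; Variable is deprecated, a plain tensor works
print(tensor.shape)  # torch.Size([1, 3, 224, 224])
tensor = tensor.cuda()  # move to the GPU (requires a CUDA-enabled build)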
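The shape fix can be checked without a real image or a GPU. The sketch below is only an illustration under assumptions: a random tensor stands in for a preprocessed image, and a single nn.Conv2d with the same weight shape as VGG16's first layer ([64, 3, 3, 3]) stands in for the whole network. It also shows two equivalent ways to add the batch dimension and how several images can be batched with torch.stack.

```python
import torch
import torch.nn as nn

# Stand-in for one preprocessed image, shape [C, H, W] = [3, 224, 224]
# (random values instead of a real photo, purely for illustration)
img_tensor = torch.rand(3, 224, 224)

# Three equivalent ways to add a leading batch dimension of size 1
a = torch.unsqueeze(img_tensor, dim=0)
b = img_tensor.unsqueeze(0)
c = img_tensor[None]
assert a.shape == b.shape == c.shape == (1, 3, 224, 224)

# Several images of the same size can be batched together with torch.stack
batch = torch.stack([img_tensor, img_tensor.clone()], dim=0)
print(batch.shape)  # torch.Size([2, 3, 224, 224])

# A conv layer shaped like VGG16's first layer: weight [64, 3, 3, 3]
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
out = conv(a)
print(conv.weight.shape)  # torch.Size([64, 3, 3, 3])
print(out.shape)          # torch.Size([1, 64, 224, 224])
```

With padding=1 and a 3x3 kernel the spatial size stays 224x224, so only the channel count changes from 3 to 64.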
Addendum: inspecting the structure of a pretrained VGG16 model:
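import torch
from torchvision import models, transforms
VGG = models.vgg16(pretrained=True)  # torchvision 0.13+ prefers weights=models.VGG16_Weights.DEFAULT
# Wrap all top-level children in one Sequential. Fine for printing, but it omits the
# flatten step that VGG.forward() performs before the classifier, so it cannot run a forward pass as-is.
feature = torch.nn.Sequential(*list(VGG.children()))
print(feature)
print('=============')
print(VGG._modules.keys())  # top-level parts: features, avgpool, classifier (older torchvision: features, classifier)
exit()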
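Why the Sequential built from VGG.children() is only good for printing can be shown on a small model. TinyVGG below is hypothetical: it is written to mimic the features / avgpool / classifier layout of torchvision's VGG so that it runs quickly without downloading weights, and it is a sketch, not torchvision's own code. Chaining the children directly skips the torch.flatten call inside forward(), while inserting nn.Flatten() restores the original computation.

```python
import torch
import torch.nn as nn

# Hypothetical miniature model mirroring torchvision's VGG layout:
# top-level children named features, avgpool and classifier,
# with torch.flatten called inside forward() between avgpool and classifier.
class TinyVGG(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
        self.classifier = nn.Sequential(
            nn.Linear(8 * 2 * 2, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # this step is not a child module
        return self.classifier(x)


model = TinyVGG().eval()
print(list(model._modules.keys()))  # ['features', 'avgpool', 'classifier']

x = torch.rand(1, 3, 32, 32)  # batched input: [N, C, H, W]

# Naively chaining all children skips the flatten, so the first Linear
# receives a [1, 8, 2, 2] tensor and raises a shape-mismatch RuntimeError
naive = nn.Sequential(*list(model.children()))
try:
    naive(x)
    naive_failed = False
except RuntimeError:
    naive_failed = True

# Inserting nn.Flatten() (start_dim=1 by default) reproduces forward()
fixed = nn.Sequential(model.features, model.avgpool, nn.Flatten(), model.classifier)
with torch.no_grad():
    same = torch.allclose(fixed(x), model(x))
print(naive_failed, same)  # True True
```

The same pattern applies to the real model: nn.Sequential(VGG.features, VGG.avgpool, nn.Flatten(), VGG.classifier) should behave like VGG itself, and slicing VGG.classifier is one way to stop at an intermediate feature layer.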