module不存在的原因是因为可能预训练模型使用一个显卡训练,而我们自己训练的模型是多卡训练的,这时在加载模型的过程中就会出现module不存在的报错,解决方法直接上代码:
#create model
model = vgg(model_name=vgg16”,num_classes=5).to(device)
# load model weights
weights_path =./vgg16Net.pth
new_state = {}
state_dict = torch.load(weights_path, map_location=device)
for key,value in state_dict.items():
new_state[key.replace('module. ',' ' )]=value
model.load_state_dict(new_state)
model.eval()
with torch.no_grad():
#predict class
output = torch.squeeze(model(img.to(device))).cpu