导入pytorch官方的预训练模型:
from model import resnet34, resnet101
载入预训练模型方法:
model_weight_path = "./resnet34-333f7ec4.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
# for param in net.parameters():
# param.requires_grad = False
# change fc layer structure
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
net.to(device)
载入预训练模型方法有两种,一种直接在net = resnet34()中添加num_class参数如net = resnet34(num_class=5),另外一种使用默认num_class再另外添加一个全连接层net.fc = nn.Linear(inchannel, 5)
另外ResNet的正则化参数不一样:
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
包括验证集的数据处理采用先将图片更改到256xnxn大小在从中心裁剪224x224大小。