pytorch的torchvision中给出了很多经典的预训练模型,模型的参数和权重都是在ImageNet数据集上训练好的
加载模型
方法一:直接使用预训练模型中的参数
import torchvision.models as models
model = models.resnet18(pretrained = True) #pretrained设为True,表示使用在ImageNet上训练好的参数
方法二:使用本地磁盘上的参数(直接下载的pth文件或者是在自己数据集上训练好的参数)
import torchvision.models as models
model = models.resnet18(pretrained = False) #pretrained设为False
state_dict = torch.load('resnet18.pth') #使用本地磁盘上的模型参数文件
model.load_state_dict(state_dict) #把读入的模型参数加载到模型中
修改模型
因为预训练模型是在ImageNet数据集上训练的,而ImageNet一共有1000个类别,如果我们要训练的数据集只有20个类别,这时就需要修改模型的全连接层
import torchvision.models as models
model = models.resnet18(pretrained=True)
num_classes = 20 #自己的数据集的类别
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层
总结
import torchvision.models as models
model = models.resnet18(pretrained=True)
for p in model.parameters():
p.requires_grad = False #设为False表示只训练最后全连接层的权重,其余层不训练
num_classes = 20 #自己的数据集的类别
model.conv1 = nn.Conv2d(14, 64, kernel_size=7, stride=2, padding=3, bias=False)#改输入通道数
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层
model = nn.DataParallel(model).cuda() #gpu训练
model.eval() #排除BN和Dropout的影响,测试的时候加