pytorch 加载预训练模型

pytorch的torchvision中给出了很多经典的预训练模型,模型的参数和权重都是在ImageNet数据集上训练好的

加载模型
方法一:直接使用预训练模型中的参数

import torchvision.models as models
model = models.resnet18(pretrained = True) #pretrained设为True,表示使用在ImageNet上训练好的参数

方法二:使用本地磁盘上的参数(直接下载的pth文件或者是在自己数据集上训练好的参数)

import torchvision.models as models
model = models.resnet18(pretrained = False) #pretrained设为False
state_dict = torch.load('resnet18.pth') #使用本地磁盘上的模型参数文件
model.load_state_dict(state_dict) #把读入的模型参数加载到模型中

修改模型
因为预训练模型是在ImageNet数据集上训练的,而ImageNet一共有1000个类别,如果我们要训练的数据集只有20个类别,这时就需要修改模型的全连接层

import torchvision.models as models
model = models.resnet18(pretrained=True)
num_classes = 20 #自己的数据集的类别
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层

总结

import torchvision.models as models
model = models.resnet18(pretrained=True)
for p in model.parameters(): 
    p.requires_grad = False #设为False表示只训练最后全连接层的权重,其余层不训练
num_classes = 20 #自己的数据集的类别
model.conv1 = nn.Conv2d(14, 64, kernel_size=7, stride=2, padding=3, bias=False)#改输入通道数
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层
model = nn.DataParallel(model).cuda() #gpu训练
model.eval() #排除BN和Dropout的影响,测试的时候加
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值