PyTorch载入预训练权重方法和冻结权重方法_xception的预训练权重pytorch

如果这里的pretrain_weights与我们训练的网络不同,一般指的是包含大于模型参数时,可以修改为

net.load_state_dict(torch.load(pretrain_weights_path), strict=False)

2. 修改网络结构

常用方法1:

model_weight_path = "resnet34pre.pth"
net.load_state_dict(torch.load(model_weight_path))
# 这里假设最后一层为FC层,使用迁移学习,将分类结果修改
# net是实例化的resnet网络,in\_features是网络输入结构参数,最后的5是修改的输出参数
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
# 注意,最后去转换我们的模型设备,否则可能会报错,怀疑是修改的模型部分和原模型部分使用的设备不同
net.to(device) 

常用方法2:

net = MobileNetV2(num_class=5)
net_weights = net.state_dict()
model_weights_path = "./mobilenet\_v2pre.pth"
pre_weights = torch.load(model_weights_path)
# delete classifier weights
# 这种方法主要是遍历字典,.pth文件(权重文件)的本质就是字典的存储
# 通过改变我们载入的权重的键值对&#x
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值