迁移学习 pytorch

迁移学习(Transfer Learning)简单来说就是使用一些已经训练好的模型迁移到类似的新的问题进行使用,而不必对新问题重新建模,从头训练和优化参数。这些训练好的模型同时包含了优化好的参数,在使用的时候只需要做微调就可以应用到新问题。

比如我现在需要做一个图像二分类的任务,我可以使用已经训练好的resnet_152模型来进行微调。

首先应该将我们的输入数据处理成resnet同样的输入尺寸224*224

然后通过torchvision下载我们需要的模型:

from torchvision import models
model = models.resnet152(pretrained=True) # 参数pretrained=True表示需要下载预训练好的参数

冻结模型的网络结构,使参数不更新

for param in model.parameters():
    param.requires_grad = False

根据我们需要的输出更改网络结构

model.fc = nn.Linear(model.fc.in_features, 2) #这里模型最后的输出类别为2

后面再正常对我们的模型进行训练

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值