pytorch下分类神经网络的迁移学习transfer learning

对预训练模型的迁移引用【1】中的提法,分为两种形式

  1. 只训练最后fc层的freeze and train
  2. 以预训练模型为初始参数,训练所有层的finetune
这里只讨论网络结构的变更

finetune
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features #最后fc层的输入
model_ft.fc = nn.Linear(num_ftrs, NUM_CLASSES) #NUM_CLASSES是自己数据的类别

model_ft = models.vgg16(pretrained=True)
num_ftrs = model_ft.classifier[6].in_features
feature_model = list(model_ft.classifier.children())
feature_model.pop()            
feature_model.append(nn.Linear(num_ftrs, NUM_CLASSES))
model_ft.classifier = nn.Sequential(*feature_model)
如果在基础网络的基础上还要再增加层数,可用【2】中mian.py的方法
num_ftrs = model_ft.fc.in_features
feature_model = list(model_ft.fc.children())
feature_model.append(nn.Linear(num_ftrs, cf.feature_size))
feature_model.append(nn.BatchNorm1d(cf.feature_size))
feature_model.append(nn.ReLU(inplace=True))
feature_model.append(nn.Linear(cf.feature_size, len(dset_classes)))
model_ft.fc = nn.Sequential(*feature_model)
【2】中还提到了,特征提取的方法
if(args.net_type == 'alexnet' or args.net_type == 'vggnet'):
    feature_map = list(checkpoint['model'].module.classifier.children())
    feature_map.pop()
    new_classifier = nn.Sequential(*feature_map)
    extractor = copy.deepcopy(checkpoint['model'])
    extractor.module.classifier = new_classifier
elif (args.net_type) == 'resnet'):
    feature_map = list(model.module.children())
    feature_map.pop()
    extractor = nn.Sequential(*feature_map)

freeze_train 的网络结构在前面的基础上加入
for param in model_conv.parameters(): #params have requires_grad=True by default
        param.requires_grad = False
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, num_class)
以防止在反向传播的过程中,改变前面层的参数

【3】中说明了一下,随着训练的进行,learning_rate应该进行一定的衰减,以免在梯度下降过程中,在接近的时候local optimum的时候错过。
def lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

【1】

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值