pytorch优化器传入两个模型的参数/保存和加载两个模型的参数

该博客探讨了在人脸识别任务中如何配置Arcface_loss,并详细说明了如何根据权重更新规则将模型参数分为多个组,以便在训练过程中对它们进行独立管理和更新。博客内容涉及了如何初始化和管理两个模型的优化器,包括对权重衰减的处理,以及如何保存和加载模型参数。此外,还介绍了如何根据参数类型和需求将模型参数分组,确保不同部分的参数得到适当的训练和调整。
摘要由CSDN通过智能技术生成

在人脸识别任务中,当我定义模型backbone后,用到Arcface_loss,但这个Arcface_loss也是用nn.Module模块写的,所以实例化出来也是一个网络,而且原论文中,Arcface_loss还是以backbone权重参数10倍的权重衰减方式更新,需要单独以不同的方式训练,且网络中bn层也是不需要权重衰减的。由于这些原因,我们就需要把网络参数分开。

分别传入两个模型的参数

如果我们只有一个模型model,我们一般就是这样初始化优化器的:

optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9)

现在有两个模型model和head了,要把参数都加进优化器,并且以不同lr来训练参数,我们就可以这样来初始化优化器:

optimizer = torch.optim.SGD([{'params': model.parameters()},
                             {'params': head.parameters(), 'lr': 1e-4}],
                             lr=0.01, momentum=0.9)

还可以将整个模型按参数需不要训练更新,权重衰减来分组:

def separate_paras(modules):
    parses_bn, parses_w, parses_bias = [], [], []
    for k, v in modules.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            parses_bias.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            parses_bn.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            parses_w.append(v.weight)  # apply decay

    return parses_w, parses_bias, parses_bn

paras_w, paras_bias, paras_bn = separate_paras(model)
optimizer = torch.optim.SGD(paras_bn, lr=args.lr_model, momentum=args.momentum, nesterov=True)  # 只需设置学习率
optimizer.add_param_group({'params': paras_w[:-1], 'weight_decay': args.weight_decay})  # backbone权重参数,需要权重衰减
optimizer.add_param_group({'params': [paras_w[-1]]+[head.kernel], 'weight_decay': args.weight_decay*1.5}) #特殊层,需要单独定义权重衰减参数
optimizer.add_param_group({'params': paras_bias})  # 偏置
del paras_w, paras_bias, paras_bn    # 用完之后清除内存

保存/加载两个模型的参数

  1. 保存, 使用torch.save(state, filename)
filename = './facemodel.pth'
state = {'model':model.state_dict(),
         'head':head.state_dict()}
torch.save(state, filename)
  1. 加载,使用torch.load(),并使用字典进行匹配加载
load_name = './facemodel.pth'
checkpoint = torch.load(load_name)
model.load_state_dict(checkpoint['model'])
head.load_state_dict(checkpoint['head'])
model.cuda()
head.cuda()
PyTorch 中,可以使用 `torch.save()` 函数将训练好的模型保存到磁盘上。该函数需要传入两个参数模型的状态字典和文件名。 以下是保存模型的示例代码: ```python import torch # 假设已经训练好了一个模型保存模型 model_state = model.state_dict() torch.save(model_state, 'model.pth') ``` 加载模型时,可以使用 `torch.load()` 函数将模型状态字典从磁盘中加载出来,并使用 `load_state_dict()` 方法将模型参数加载模型中。以下是加载模型的示例代码: ```python import torch # 加载模型 model = Model() # 这里的 Model 是你定义的模型model_state = torch.load('model.pth') model.load_state_dict(model_state) ``` 在加载模型时,需要确保模型类定义中的参数保存模型状态字典中的参数名称和顺序一致。如果有不一致的地方,可以在加载模型时使用 `strict=False` 参数来禁用严格模式,这样可以忽略一些不一致的参数。 ```python import torch # 加载模型(禁用严格模式) model = Model() # 这里的 Model 是你定义的模型model_state = torch.load('model.pth') model.load_state_dict(model_state, strict=False) ``` 使用加载好的模型进行预测时,只需要将数据传入模型即可。以下是使用模型进行预测的示例代码: ```python import torch # 加载模型 model = Model() # 这里的 Model 是你定义的模型model_state = torch.load('model.pth') model.load_state_dict(model_state) # 使用模型进行预测 input_data = torch.randn(1, 3, 224, 224) # 假设输入数据为 1 张 3 通道的 224x224 图片 output = model(input_data) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值