机器学习的Tricks:随机权值平均(Stochastic Weight Averaging,SWA)

SWA:随机权值平均(Stochastic Weight Averaging)
每次学习率循环结束时产生的局部最小值趋向于再损失面的边缘域
通过对这几个这样的点取平均,很有可能得到一个更低损失的全局化的通用解
此trick不牺牲inference latency
https://arxiv.org/pdf/1803.05407.pdf
在这里插入图片描述
在这里插入图片描述

torch中lr scheduler 设置如下 :
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=1, eta_min=1e-5, last_epoch=-1)
lr= optimizer.param_groups[-1][‘lr’]
scheduler.step(epoch + iteration + 1 / max_iter)

import os
import torch

def main():
    model_dir = 'model/path'
    save_dir = 'swa_model/resnet18'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_names = [2,5,8,11]  #设置12 epoch,提取lr=0时 epoch
    model_dirs = [
        os.path.join(model_dir, 'checkpoint-epoch' + str(i) + '.pth')
        for i in model_names
    ]
    print('model_dirs',model_dirs)
    models = [torch.load(model_dir) for model_dir in model_dirs]
    model_num = len(models)
    model_keys = models[-1]['state_dict'].keys()
    state_dict = models[-1]['state_dict']
    new_state_dict = state_dict.copy()
    ref_model = models[-1]

    for key in model_keys:
        sum_weight = 0.0
        for m in models:
            sum_weight += m['state_dict'][key]
        avg_weight = sum_weight / model_num
        new_state_dict[key] = avg_weight
    ref_model['state_dict'] = new_state_dict
    save_model_name = 'checkpoint-best.pth'
    save_dir = os.path.join(save_dir, save_model_name)
    torch.save(ref_model, save_dir)
    print('Model is saved at', save_dir)
if __name__ == '__main__':
    main()
    
    
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值