SWA: Stochastic Weight Averaging
At the end of each learning-rate cycle, SGD lands in a local minimum that tends to sit near the edge of a wide, flat region of the loss surface.
Averaging several such points, w_SWA = (w_1 + w_2 + ... + w_n) / n, is likely to move into the interior of that region and yield a lower-loss solution that generalizes better.
Since the result is a single set of weights, this trick adds no inference latency.
https://arxiv.org/pdf/1803.05407.pdf
In PyTorch, the lr scheduler is set up as follows:
import torch
import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=1, eta_min=1e-5, last_epoch=-1)

# inside the training loop, step with a fractional epoch so the lr decays
# smoothly within each cycle (max_iter = number of batches per epoch):
lr = optimizer.param_groups[-1]['lr']   # current lr, e.g. for logging
scheduler.step(epoch + (iteration + 1) / max_iter)
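To see why the checkpoints at epochs 2, 5, 8 and 11 are averaged below: with T_0=3 and T_mult=1, the lr bottoms out at the last epoch of every 3-epoch cycle. A minimal sketch to print the per-epoch schedule (the Linear dummy model and the 0.1 initial lr are placeholders, not from the original setup):

import torch
import torch.optim as optim

net = torch.nn.Linear(4, 2)  # dummy model, only to give the optimizer parameters
optimizer = optim.SGD(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=1, eta_min=1e-5)

for epoch in range(12):
    # the lowest lr of each 3-epoch cycle falls on epochs 2, 5, 8, 11;
    # with the fractional per-iteration step above, the lr reaches eta_min there
    print(epoch, scheduler.get_last_lr()[0])
    scheduler.step()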
import os
import torch


def main():
    model_dir = 'model/path'
    save_dir = 'swa_model/resnet18'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Trained for 12 epochs with T_0=3: the lr reaches its per-cycle minimum
    # at epochs 2, 5, 8 and 11, so those checkpoints are the ones to average.
    model_names = [2, 5, 8, 11]
    model_paths = [
        os.path.join(model_dir, 'checkpoint-epoch' + str(i) + '.pth')
        for i in model_names
    ]
    print('model_paths', model_paths)

    models = [torch.load(path, map_location='cpu') for path in model_paths]
    model_num = len(models)

    # Use the last checkpoint as a template and overwrite its weights with
    # the element-wise average over all checkpoints.
    ref_model = models[-1]
    state_dict = ref_model['state_dict']
    new_state_dict = state_dict.copy()
    for key in state_dict.keys():
        sum_weight = 0.0
        for m in models:
            sum_weight += m['state_dict'][key]
        # cast back to the original dtype (e.g. integer BatchNorm counters)
        new_state_dict[key] = (sum_weight / model_num).to(state_dict[key].dtype)
    ref_model['state_dict'] = new_state_dict

    save_path = os.path.join(save_dir, 'checkpoint-best.pth')
    torch.save(ref_model, save_path)
    print('Model is saved at', save_path)


if __name__ == '__main__':
    main()
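Since PyTorch 1.6 the same trick is built in as torch.optim.swa_utils, which also re-estimates BatchNorm running statistics; note that the manual averaging above keeps the BN buffers of the last checkpoint, so recomputing them before evaluation is advisable. A minimal sketch, assuming net, optimizer, train_loader, num_epochs, swa_start and a train_step helper are defined elsewhere:

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

swa_model = AveragedModel(net)                 # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # constant lr for the SWA phase

for epoch in range(num_epochs):
    for batch in train_loader:
        train_step(net, optimizer, batch)      # hypothetical training-step helper
    if epoch >= swa_start:
        swa_model.update_parameters(net)       # fold current weights into the average
        swa_scheduler.step()

update_bn(train_loader, swa_model)             # re-estimate BatchNorm statistics
torch.save(swa_model.module.state_dict(), 'swa-best.pth')

Evaluate swa_model (or a model loaded from the saved state_dict) only after update_bn has run; skipping that step is a common source of degraded SWA accuracy.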