radam+lookahead optimizer
torch 中 optimizer的类
optimizer的类中的link optimizer详解
主要说明optimizer里面的param_groups这个属性,param_group是一个list,在里面存放着字典 dict,dict中的key包括下面的几个部分:
lookahead 源码解析
lookahead 使用的是向前k步,然后倒退一步,这可以理解为,向前探索的过程中可能会遇到死胡同,这时候有一个小伙伴站在了你俩中间某个位置等待,你找不到出路时可以回去找他。
class Lookahead(Optimizer):
def __init__(self, optimizer, alpha=0.5, k=6):
# alpha 控制小伙伴距离我有多远,1时紧跟我,0时在原点等待我
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
self.optimizer = optimizer
self.param_groups = self.optimizer.param_groups
self.alpha = alpha
self.k = k
for group in self.param_groups:
group["step_counter"] = 0
# 小伙伴对我的权重进行复制,保持requires_grad = False(detach就能完成)
self.slow_weights = [
[p.clone().detach() for p in group['params']]
for group in self.param_groups]
for w in it.chain(*self.slow_weights):
w.requires_grad = False
self