WN的理解
1、计算梯度 grads
2、获得lr,如果需要衰减就对应衰减
3、迭代次数更新 t=iterations+1
4、计算本次迭代的lr,lr_t = lr*sqrt(1-pow(beta_2,t))/(1-pow(beta_1,t))
初始状态
params:初始值
grads:根据p和loss计算得到
ms:初始为0,不断更新
vs:初始为0,不断更新
根据parms、grads计算g和V
因为 W = (g/||V||)*V,其中V_scaler = g/||V||,则W = V_scaler*V
V_scaler = g/||V||,初始化为1,也是不断更新的
W已知,V_scaler也已知,可得到 V = W/V_scaler
根据V计算||V||
再根据V_scaler和||V||,得到g=V_scaler*||V||
计算g和V的梯度,根据论文里边的公式,用到grads
用Adam方法更新g和V
用新的g和V更新W
用V计算||V||
V_scaler = g/||V||得到更新
W = V_scaler*V
更新g、v、p