Key methods
- _create_slots: creates the associated slot variables, one per trainable variable, that the update step needs.
- _resource_apply_dense and _resource_apply_sparse: called for every gradient update; each returns the op that updates the variable (the update rule they implement is summarized below).
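For reference, the update rule these methods implement (with g_t the gradient at step t, and m, v the two slot variables) is:

m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
v_t = (1 + beta_2) * v_{t-1} + beta_2 * g_t^2
v_hat = v_t / ((1 + beta_2)^t - 1)    # AdaX bias correction
theta_t = theta_{t-1} - lr * m_t / (sqrt(v_hat) + epsilon)

Note that, unlike Adam, v accumulates with factor (1 + beta_2) instead of decaying with beta_2, so older squared gradients are never forgotten.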
The AdaX optimizer implementation is as follows:
import tensorflow as tf
class AdaX(tf.keras.optimizers.Optimizer):
r"""Implements AdaX algorithm.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 1e-4))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-12)
weight_decay (float, optional): L2 penalty (default: 5e-4)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.0001,
epsilon=1e-6,
**kwargs
):
kwargs['name'] = kwargs.get('name') or 'AdaX_V2'
super(AdaX, self).__init__(**kwargs)
self._set_hyper('learning_rate', learning_rate)
self._set_hyper('beta_1', beta_1)
self._set_hyper('beta_2', beta_2)
self.epsilon = epsilon
    def _create_slots(self, var_list):
        '''Create the slot variables used by the gradient computation.
        var_list: list of trainable variables to be updated.
        '''
        for var in var_list:
            self.add_slot(var, 'm')  # first moment: running average of the gradient
            self.add_slot(var, 'v')  # second moment: running average of the squared gradient
    def _resource_apply(self, grad, var, indices=None):
        '''Update rule applied to each variable's gradient.'''
        # Prepare hyperparameters and slot variables.
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        epsilon_t = tf.cast(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        # Update the moment estimates.
        if indices is None:
            # Dense gradient: update m and v in full.
            m_t = m.assign(beta_1_t * m + (1 - beta_1_t) * grad)
            v_t = v.assign((1 + beta_2_t) * v + beta_2_t * grad**2)
        else:
            # Sparse gradient: first decay all entries of m and v, then
            # scatter-add the gradient contributions into the touched rows.
            mv_ops = [
                m.assign(beta_1_t * m),
                v.assign((1 + beta_2_t) * v)
            ]
            with tf.control_dependencies(mv_ops):
                m_t = self._resource_scatter_add(
                    m, indices, (1 - beta_1_t) * grad
                )
                v_t = self._resource_scatter_add(
                    v, indices, beta_2_t * grad**2)
        # Return the variable-update op.
        # tf.control_dependencies ensures the moment updates above have run
        # before the code inside the block executes.
        with tf.control_dependencies([m_t, v_t]):
            # AdaX bias correction: divide by (1 + beta_2)^t - 1.
            v_t = v_t / (tf.pow(1.0 + beta_2_t, local_step) - 1.0)
            var_t = var.assign(var - lr_t * m_t / (tf.sqrt(v_t) + epsilon_t))
        return var_t
    def _resource_apply_dense(self, grad, var):
        '''Called for every dense gradient update.'''
        return self._resource_apply(grad, var)
    def _resource_apply_sparse(self, grad, var, indices):
        '''Called for every sparse gradient update.'''
        return self._resource_apply(grad, var, indices)
def get_config(self):
config = {
'learning_rate': self._serialize_hyperparameter('learning_rate'),
'beta_1': self._serialize_hyperparameter('beta_1'),
'beta_2': self._serialize_hyperparameter('beta_2'),
'epsilon': self.epsilon,
}
base_config = super(AdaX, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
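Once defined, AdaX plugs into model.compile like any built-in optimizer. Below is a minimal usage sketch; the tiny regression model and the random data are hypothetical examples, not part of the original post, and the code assumes a TF 2.x version where tf.keras.optimizers.Optimizer is the OptimizerV2 base class (on TF 2.11+ subclass tf.keras.optimizers.legacy.Optimizer instead).

# Minimal usage sketch: the model and data below are hypothetical
# examples used only to exercise the AdaX optimizer defined above.
import numpy as np

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=AdaX(learning_rate=1e-3), loss='mse')

x = np.random.randn(256, 10).astype('float32')
y = np.random.randn(256, 1).astype('float32')
model.fit(x, y, batch_size=32, epochs=1)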
Reference: https://github.com/bojone/adax