11年it研发经验,从一个会计转行为算法工程师,学过C#,c++,java,android,php,go,js,python,CNN神经网络,四千多篇博文,三千多篇原创,只为与你分享,共同成长,一起进步,关注我,给你分享更多干货知识!
mxnet 优化器梯度中心化(Gradient Centralization)代码 —— 为 SGD、Adam 等所有优化器自动生成 GC 变体(SGDGC、AdamGC 等):
原文:https://github.com/mnikitin/Gradient-Centralization
调用代码:
import optimizer
opt_params = {'learning_rate': 0.001}
sgd_gc = optimizer.SGDGC(gc_type='gc', **opt_params)
sgd_gcc = optimizer.SGDGC(gc_type='gcc', **opt_params)
adam_gc = optimizer.AdamGC(gc_type='gc', **opt_params)
adam_gcc = optimizer.AdamGC(gc_type='gcc', **opt_params)
python3 mnist.py --optimizer sgdgc --gc-type gc --lr 0.1 --seed 42
python3 mnist.py --optimizer adamgc --gc-type gcc --lr 0.001 --seed 42
import mxnet as mx
__all__ = []
def _register_gc_opt():
    """Create and register gradient-centralized (GC) variants of every
    optimizer that directly subclasses ``mx.optimizer.Optimizer``.

    For each optimizer class ``Foo`` found in ``mx.optimizer``, a new class
    ``FooGC`` is generated that centralizes gradients (subtracts the
    per-output-channel mean) before delegating the actual weight update to
    ``Foo``.  Each generated class is registered with mxnet's optimizer
    registry, bound to this module's globals, and appended to ``__all__``.

    Based on "Gradient Centralization: A New Optimization Technique for
    Deep Neural Networks" (Yong et al., 2020).
    """
    # Collect direct subclasses of Optimizer (SGD, Adam, ...).  Only direct
    # subclasses are wrapped, matching the original selection criterion.
    base_optimizers = {}
    for name in dir(mx.optimizer):
        obj = getattr(mx.optimizer, name)
        if getattr(obj, '__base__', None) is mx.optimizer.Optimizer:
            base_optimizers[name] = obj

    def _make_methods(base):
        """Build the method dict for the GC subclass of ``base``.

        The concrete base class is captured in this closure instead of using
        ``super(self.__class__, self)``: with the latter, the call target is
        re-resolved from the *runtime* class, so subclassing a generated GC
        class would recurse infinitely.  Closing over ``base`` pins the
        delegation target and makes the generated classes safely subclassable.
        """

        def __init__(self, gc_type='gc', **kwargs):
            # 'gc'  : centralize every gradient with ndim > 1 (FC + conv layers)
            # 'gcc' : centralize only gradients with ndim > 3 (conv layers only)
            # Raise (not assert) so validation survives `python -O`.
            if gc_type.lower() not in ('gc', 'gcc'):
                raise ValueError(
                    "gc_type must be 'gc' or 'gcc', got %r" % (gc_type,))
            self.gc_ndim_thr = 1 if gc_type.lower() == 'gc' else 3
            base.__init__(self, **kwargs)

        def update(self, index, weight, grad, state):
            self._gc_update_impl(
                index, weight, grad, state,
                lambda i, w, g, s: base.update(self, i, w, g, s))

        def update_multi_precision(self, index, weight, grad, state):
            self._gc_update_impl(
                index, weight, grad, state,
                lambda i, w, g, s: base.update_multi_precision(self, i, w, g, s))

        def _gc_update_impl(self, indexes, weights, grads, states, update_func):
            """Centralize ``grads`` in place, then delegate to ``update_func``."""
            if isinstance(indexes, (list, tuple)):
                # Multi-index case (e.g. SGD aggregate update): grads is a list.
                for grad in grads:
                    if len(grad.shape) > self.gc_ndim_thr:
                        # Subtract the mean over all non-leading axes
                        # (per output channel), keeping dims for broadcasting.
                        grad -= grad.mean(
                            axis=tuple(range(1, len(grad.shape))), keepdims=True)
            else:
                # Single-index case: grads is one NDArray.
                if len(grads.shape) > self.gc_ndim_thr:
                    grads -= grads.mean(
                        axis=tuple(range(1, len(grads.shape))), keepdims=True)
            # Update weights using the centralized gradients.
            update_func(indexes, weights, grads, states)

        return dict(
            __init__=__init__,
            update=update,
            update_multi_precision=update_multi_precision,
            _gc_update_impl=_gc_update_impl,
        )

    suffix = 'GC'
    for base_name, base_cls in base_optimizers.items():
        gc_name = base_name + suffix
        # Each generated class gets its own method dict bound to its own base.
        gc_cls = type(gc_name, (base_cls,), _make_methods(base_cls))
        mx.optimizer.Optimizer.register(gc_cls)
        globals()[gc_name] = gc_cls
        __all__.append(gc_name)
_register_gc_opt()
if __name__ == '__main__':
    # The GC variants were registered at import time by _register_gc_opt(),
    # so they are already available as module-level names; re-importing the
    # module by a hard-coded name ('import optimizer') was redundant and
    # would fail if the file is saved under a different name.
    # Smoke check: list every registered GC optimizer class.
    for _name in sorted(__all__):
        print('registered GC optimizer:', _name)
    # Example usage (requires mxnet):
    #   opt = SGDGC(gc_type='gc', learning_rate=0.001)
    #   opt = AdamGC(gc_type='gcc', learning_rate=0.001)