tensorflow中tf.raw_ops.ApplyCenteredRMSProp()函数的使用

1. 功能

居中RMSProp算法使用居中第二矩(即方差)的估计值进行归一化,而普通RMSProp则使用(非居中)第二矩。这通常有助于训练,但在计算和内存方面略显昂贵。
需要注意的是,在这个算法的密集实现中,即使grad为零,mg、ms和mom也会更新,但在这个稀疏实现中,mg、ms和mom在grad为零的迭代中不会更新。
mean_square = decay * mean_square + (1-decay) * gradient ** 2
mean_grad = decay * mean_grad + (1-decay) * gradient
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
mg <- rho * mg_{t-1} + (1-rho) * grad
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
var <- var - mom

2. 参数

var:可变的 Tensor 。必须为以下类型之一: float32 , float64 , int32 , uint8 , int16 , int8 , complex64 , int64 , qint8 , quint8 , qint32 , bfloat16 , uint16 , complex128 , half , uint32 , uint64 。应该来自Variable()。
mg:可变的 Tensor 。必须具有与 var 相同的类型。应该来自Variable()。
ms:可变的 Tensor 。必须具有与 var 相同的类型。应该来自Variable()。
mom:可变的 Tensor 。必须具有与 var 相同的类型。应该来自Variable()。
lr:学习率,一个 Tensor 。必须具有与 var 相同的类型。必须是Tensor。
rho:衰减率一个 Tensor 。必须具有与 var 相同的类型。衰减率。必须是Tensor。
momentum:动量,一个 Tensor 。必须与 var 具有相同的类型。动量标度。必须是Tensor。
epsilon:一个 Tensor 。必须具有与 var 相同的类型。岭学期。必须是Tensor。
grad:一个 Tensor 。必须具有与 var 相同的类型。渐变。
use_locking:可选的 bool 。默认为 False 。如果为 True ,则通过锁保护var,mg,ms和mom张量的更新;否则,行为是不确定的,但可能会减少争用。
name:操作的名称(可选)。

3. 代码样例

data_type = np.float16
idxs_np = np.random.randint(0, 3, size=3).astype(np.int32)
var = np.random.random(size=(3, 3)).astype(data_type)
mg = np.random.random(size=(3, 3)).astype(data_type)
ms = np.random.random(size=(3, 3)).astype(data_type)
mom = np.random.random(size=(3, 3)).astype(data_type)
grad_np = np.random.rand(*(3, 3)).astype(data_type)
lr = 0.0
decay = 1e-10
momentum = 0.001
epsilon = 0.1

uni_idx, idx = tf.unique(idxs_np)
var_tf = tf.Variable(tf.gather(var, uni_idx, axis=0), validate_shape=False)
ms_tf = tf.Variable(tf.gather(ms, uni_idx, axis=0), validate_shape=False)
mg_tf = tf.Variable(tf.gather(mg, uni_idx, axis=0), validate_shape=False)
mom_tf = tf.Variable(tf.gather(mom, uni_idx, axis=0), validate_shape=False)
grad_tf = tf.Variable(tf.gather(grad_np, uni_idx, axis=0), validate_shape=False)

out = tf.raw_ops.ApplyCenteredRMSProp(var=var_tf, mg=mg_tf, ms=ms_tf, mom=mom_tf, lr=lr, rho=decay, momentum=momentum,
                                      epsilon=epsilon, grad=grad_tf, use_locking=False, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_ = sess.run(out)
    print(out_)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浅蓝的风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值