昇思MindSpore进阶教程--Per-sample-gradients

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

正文开始

计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。

在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。

MindSpore为我们提供了更高效的方法来计算per-sample-gradients。

我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,
V ( S t ) = V ( S t ) + α ( R t + 1 + γ V ( S t + 1 ) − V ( S t ) ) V(S_{t}) = V(S_{t}) + \alpha (R_{t+1} + \gamma V(S_{t+1}) - V(S_{t})) V(St)=V(St)+α(Rt+1+γV(St+1)V(St))
通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。

在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。

下面对该方法进行介绍,假设在状态 s t s_{t} st时的估计值 v θ v_{\theta} vθ由一个线性函数进行参数化。

from mindspore import ops, Tensor, vmap, jit, grad


value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])

考虑如下场景,从状态 s t s_{t} st转换到状态 s t + 1 s_{t+1} st+1,且在这个过程中,我们观察到的奖励为 r t + 1 r_{t+1} rt+1

s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])

参数 θ {\theta} θ的更新量的计算公式为:
Δ θ = ( r t + 1 + v θ ( s t + 1 ) − v θ ( s t ) ) ∇ v θ ( s t ) \Delta{\theta}=(r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t}))\nabla v_{\theta}(s_{t}) Δθ=(rt+1+vθ(st+1)vθ(st))vθ(st)
我们给出伪损失函数 L ( θ ) L(\theta) L(θ)在MindSpore中的实现,

def td_loss(theta, s_tm1, r_t, s_t):
    v_t = value_fn(theta, s_t)
    target = r_tp1 + value_fn(theta, s_tp1)
    return (ops.stop_gradient(target) - v_t) ** 2

将td_loss传入grad中,计算td_loss关于theta的梯度,即theta的更新量。

td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)

td_update仅根据一个样本,计算td_loss关于参数
的梯度,我们可以使用vmap对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。

batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])

per_sample_grads = vmap(td_update)
batch_theta = ops.stack([theta, theta])
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

在上面的例子中,我们需要手动地为per_sample_grads传递一批量的theta,但实际上,我们可以仅传入单个的theta,为了实现这一点,我们对vmap传入参数in_axes,在in_axes中,参数theta对应的位置被设置为None,其他参数对应的位置被设置为0。这使得我们仅需向除theta以外的其他参数添加一个额外的轴。

inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用jit调用inefficiecient_per_sample_grads,这会将inefficiecient_per_sample_grads编译为一张可调用的MindSpore图,这会提升它的运行效率。

efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)
delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

明志刘明

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值