大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧
正文开始
计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。
在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。
MindSpore为我们提供了更高效的方法来计算per-sample-gradients。
我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,
V
(
S
t
)
=
V
(
S
t
)
+
α
(
R
t
+
1
+
γ
V
(
S
t
+
1
)
−
V
(
S
t
)
)
V(S_{t}) = V(S_{t}) + \alpha (R_{t+1} + \gamma V(S_{t+1}) - V(S_{t}))
V(St)=V(St)+α(Rt+1+γV(St+1)−V(St))
通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。
在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。
下面对该方法进行介绍,假设在状态 s t s_{t} st时的估计值 v θ v_{\theta} vθ由一个线性函数进行参数化。
from mindspore import ops, Tensor, vmap, jit, grad
value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])
考虑如下场景,从状态 s t s_{t} st转换到状态 s t + 1 s_{t+1} st+1,且在这个过程中,我们观察到的奖励为 r t + 1 r_{t+1} rt+1。
s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])
参数
θ
{\theta}
θ的更新量的计算公式为:
Δ
θ
=
(
r
t
+
1
+
v
θ
(
s
t
+
1
)
−
v
θ
(
s
t
)
)
∇
v
θ
(
s
t
)
\Delta{\theta}=(r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t}))\nabla v_{\theta}(s_{t})
Δθ=(rt+1+vθ(st+1)−vθ(st))∇vθ(st)
我们给出伪损失函数
L
(
θ
)
L(\theta)
L(θ)在MindSpore中的实现,
def td_loss(theta, s_tm1, r_t, s_t):
v_t = value_fn(theta, s_t)
target = r_tp1 + value_fn(theta, s_tp1)
return (ops.stop_gradient(target) - v_t) ** 2
将td_loss传入grad中,计算td_loss关于theta的梯度,即theta的更新量。
td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)
td_update仅根据一个样本,计算td_loss关于参数
的梯度,我们可以使用vmap对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。
batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])
per_sample_grads = vmap(td_update)
batch_theta = ops.stack([theta, theta])
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
在上面的例子中,我们需要手动地为per_sample_grads传递一批量的theta,但实际上,我们可以仅传入单个的theta,为了实现这一点,我们对vmap传入参数in_axes,在in_axes中,参数theta对应的位置被设置为None,其他参数对应的位置被设置为0。这使得我们仅需向除theta以外的其他参数添加一个额外的轴。
inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)
到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用jit调用inefficiecient_per_sample_grads,这会将inefficiecient_per_sample_grads编译为一张可调用的MindSpore图,这会提升它的运行效率。
efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)
delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)