import torch
# 创建一个需要梯度计算的FloatTensor
a = torch.tensor([2., 3.], dtype=torch.float, requires_grad=True)
# 计算 b = a * 3
b = a * 3
# 计算 c = b * b * 3
c = b * b * 3
# 计算 c 的均值
out = c.mean()
# 反向传播,计算梯度
out.backward()
在反向传播阶段,out.backward()
会计算out
关于a
的梯度。考虑到前面的运算步骤,我们需要计算out
关于a
的偏导数,即dc/da
。
根据链式法则,我们可以分解这个过程:
c = (a*3)^2 * 3
,首先计算(a*3)
,记为b
,然后计算b
的平方,并乘以3。- 因此,
dc/db = 2 * b * 3
,然后db/da = 3
。 - 所以,
dc/da = dc/db * db/da = 2 * b * 3 * 3
,因为b
在计算梯度时已经是关于a
的表达式,所以b
的值会被代入。