2.5.4 pytorch中的非标量反向传播
在PyTorch中有个简单的规定,不让张量对张量求导,只允许标量对张量求导。因此,目标量对一个非标量调用backward(),则需要传入一个gradient参数。传入这个参数就是为了把张量对张量的求导转换为标量对张量的求导。
假设目标值loss=(y1,y2,y3…,ym),传入的参数为 v=(v1,v2,v3…,vm),那么就可以把对loss的求导,转换为对loss*vT标量的求导。即把原来loss对x求导的雅克比矩阵乘以张量v的转置,便可得到我们需要的梯度矩阵。
backward的格式为:
backward(gradient=None, retain_graph=None, create_graph=False)
下面看看代码实例:
# 非标量反向传播
import torch
# 1、定义叶子节点及计算节点
# 定义叶子节点张量x
x = torch.tensor([2, 3], dtype=torch.float, requires_grad=True)
# 初始化雅克比矩阵
J = torch.zeros(2, 2)
# 初始化目标张量,形状为1×2
y = torch.zeros(1, 2)
# 定义y与x之间的映射关系
y[0, 0] = x[0] ** 2 + 3 * x[1]
y[0, 1] = x[1] ** 2 + 2 * x[0]
这个我们用手工计算一下,它的梯度为
我们直接调用backward来进行反向传播:
# 非标量反向传播
import torch
# 1、定义叶子节点及计算节点
# 定义叶子节点张量x
x = torch.tensor([2, 3], dtype=torch.float, requires_grad=True)
# 初始化雅克比矩阵
J = torch.zeros(2, 2)
# 初始化目标张量,形状为1×2
y = torch.zeros(1, 2)
# 定义y与x之间的映射关系
y[0, 0] = x[0] ** 2 + 3 * x[1]
y[0, 1] = x[1] ** 2 + 2 * x[0]
# 2、调用backward来获取y对x的梯度
y.backward(torch.Tensor([[1, 1]]))
print(x.grad)
# 结果显然是错误的
结果:tensor([6., 9.])
因为我们现在是张量对张量求导,跟前面说的情况不一样。这里我们可以分成两部进行计算。首先让gradient参数为(1,0)得到y1对x的求导,然后让gradient参数为(0,1),得到y2对x的求导,最后梯度叠加。注意:**这里因为重复使用backward(),需要使参数retain_graph=True,代码如下:
# 3、正确计算张量对张量求导
# 生成y1对x的梯度
y.backward(torch.Tensor([[1, 0]]), retain_graph=True)
J[0] = x.grad
# 梯度是累加的,故需要对x的梯度清零
x.grad = torch.zeros_like(x.grad)
# 生成y2对x的梯度
y.backward(torch.Tensor([[0, 1]]))
J[1] = x.grad
# 显示雅克比矩阵的值
print("J雅克比矩阵的值:{}".format(J))
结果:
J雅克比矩阵的值:tensor([[4., 3.],
[2., 6.]])
于手工计算的结果一致