本博客是阅读eat pytorch in 20 day第二章的个人笔记
动态计算图
计算图由节点和边组成,节点是张量和函数,边表示依赖关系。动态的含义是,前向传播时每一步会立即得到计算结果,反向传播后计算图会立即销毁。
function同时包含正向计算和反向传播的逻辑,比如relu函数:
class MyReLU(torch.autograd.Function):
#正向传播逻辑,可以用ctx存储一些值,供反向传播使用。
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
#反向传播逻辑
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
relu = MyReLU.apply
只有叶子节点的梯度才会被存到.grad属性里,其余节点的梯度只在计算中出现,并不保存。
.retain_grad()非叶子节点梯度保存、register_hook非叶子节点梯度显示。
import torch
#正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2
#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()
#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)
计算图在TensorBoard中的可视化