PyTorch使用动态图技术,下面是一个简单线性变换计算图的搭建实例。
import torch w=torch.tensor([1.], requires_grad=True) x=torch.tensor([2.], requires_grad=True) a=torch.add(w, x) b=torch.add(w, 1) y=torch.mul(a, b) a.retain_grad() b.retain_grad() # 保存非叶子节点的梯度,否则方向传播完会被释放掉 y.backward() # 反向传播 print(w.grad) print(x.grad) # 查看叶子节点、梯度、计算方法 print('is_leaf:\n', w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf) print('gradient:\n', w.grad, x.grad, a.grad, b.grad, y.grad) print('grad_fn:\n', w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)