目录
1 tensor的概念
- 关于torch.Tensor
- torch.Tensor是整个package中的核心类,如果将.require_grad设置为True,它将追踪在这个类上定义的所有操作,当代吗要进行反向传播时,直接调用..backward()就可以自动计算所有的梯度,在这个Tensor上的所有梯度将被累加进属性.grad中。
- 如果想要终止一个Tensor在计算图中的追踪回溯,只需要执行.detach()就可以将该Tensor从计算图中撤下,在未来的回溯计算中也不会继续计算此张量。
- 除了.detach(),如果想要终止对计算图的回溯,也就是不再进行方向传播求导的过程,也可以用代码块的方式with torch.no_grad(),这种方式非常适用于对模型进行预测的时候,因为预测阶段不再需要对梯度进行计算。
- 关于torch.Function
- Function类是和Tensor类同等重要的一个核心类,它和tensor共同构建了一个完整的类,每个tensor拥有一个.grad_fn属性,代表引用了哪个具体的Function创建了该Tensor
- 如果某个张量Tensor是用户自定义的,则其对应的grad_fn is None