register_hook是用来对tensor的grad进行操作的
在CAM中有遇到 register_hook
import torch v = torch.ones((1, 3), dtype=torch.float32, requires_grad=True) y = v**2 z = y.sum() y.register_hook(lambda grad: print(grad)) z.backward() print(v.grad)
Pytorch tensor.register_hook(对梯度grad进行操作)
于 2022-02-13 22:02:40 首次发布
关键词由CSDN通过智能技术生成