pytorch 的 hook 函数分为 torch.Tensor 和 torch.nn.Module 两类, 分别对应
-
torch.Tensor.register_hook
-
torch.nn.Module.register_backward_hook
一般用第1个比较多。
每一个tensor
都有register_hook
方法,每次当关于这个参数的gradient
被计算出来以后都会调用这个方法,因此可以用于debug
等等,下面是对一部分梯度进行mask
。
def _emb_hook(self, grad):
return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad)
def set_grad_mask(self, mask):
self.grad_mask = torch.from_numpy(mask)
self.embedding.weight.register_hook(self._emb_hook)