hook函数
为了节省显存(内存),PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除(remove),以避免每次都运行钩子增加运行负载。
这里总结一下并给出实际用法和注意点。hook方法有4种:
1、Tensor.register_hook()
2、torch.nn.Module.register_forward_hook()
3、torch.nn.Module.register_backward_hook()
4、torch.nn.Module.register_forward_pre_hook()

1.Tensor.register_hook(hook):对于单个张量,可以使用register_hook()方法注册一个hook。该方法将一个函数(即hook)注册到张量上,在张量被计算时调用该函数。这个函数可以用来获取张量的梯度或值,或者对张量进行其他操作。例如,以下代码演示了如何使用register_hook()方法获取张量的梯度:
import torch
x = torch.randn(2, 2, requires_grad=True)
def print_grad(grad):
print(grad)
hook_handle = x.register_hook(print_grad)
y = x.sum()
y.backward()
hook_handle.remove()
在这个例子中,我们创建了一个包含梯度的张量x,并使用register_hook()方法注册了一个打印梯度的hook函数print_grad()。在计算y的梯度时,hook函数被调用并打印梯度。最后,我们使用hook_handle.remove()方法从张量中删除hook函数。

2.torch.nn.Module.register_forward_hook(hook):对于模型中的每个层,可以使用register_forward_hook()方法注册一个hook函数。这个函数将在模型的前向传递中被调用,并可以用来获取层的输出。例如,以下代码演示了如何使用register_forward_hook()方法获取模型的某一层的输出:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.avgpool(x)
return x
def hook(module, input, output):
print(output.shape)
model = MyModel()
handle = model.conv2.register_forward_hook(hook)
x = torch.

最低0.47元/天 解锁文章
1058

被折叠的 条评论
为什么被折叠?



