一个例子教会你使用pytorch的Hook函数
什么是hook? pytorch提供的一个函数,Hook函数能获取网络中的中间变量,包括forward和backward中的输入输出,如:weight、grad、loss等。
为什么叫hook? Hook函数机制是不改变函数主体,实现额外功能,像一个挂件,挂钩。
使用场景? 进行网络调试的时候,我们希望打印网络中的中间变量Tesnor(网络中涉及到的Tensor),所以就出现了Hook函数。我们可以使用Hook函数优雅地(在不更改网络结构代码的情况下)打印网络中的grad、weight、以及每层的输入输出结果。
PyTorch提供的Hook函数(这里仅为示例,不止四种):
1、torch.Tensor.register_hook(hook_func) //获取Tensor
2、torch.nn.Module.register_forward_hook(hook_func) //获取前向过程中的中间变量
3、torch.nn.Module.register_forward_pre_hook(hook_func) //获取前向过程前的变量
4、torch.nn.Module.register_backward_hook(hook_func) //获取反向过程中的中间变量
下面是一个使用Hook(register_forward_hook)函数的例子:
- 首先定义一个网络结构
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
- 针对上述定义的网络,介绍注册函数register_forward_hook和自定义函数hook_fn_forward的使用方法(打印指定层conv2的前向输出)
# 定义 forward hook function,输入是module,将其前向传播过程中,
# module的输入输出存到全局变量total_feat_in和total_feat_out中,方便后续调用或查看
total_feat_out = []
total_feat_in = []
def hook_fn_forward(module, input, output):
print("打印层:",module) # 用于区分模块
print(module, '层的forward input:', input) # 首先打印出来
print(module, '层的forward output:', output)
total_feat_in.append(input)# 然后分别存入全局 list 中
total_feat_out.append(output)
#定义模型
model=LeNet()
# 对conv2层注册hook
modules = model.named_children()
for name, module in modules:
if name == "conv2":
module.register_forward_hook(hook_fn_forward)
#module.register_full_backward_hook(hook_fn_backward)
#构造输入,跑一遍forward前向传播
fake_img = torch.randn(4, 3, 32, 32)
output = model(fake_img)
上述代码运行结果如下:
- 另外几种hook函数使用方式也类似,以下是一些简单的示例,大家可以作为参考
def grad_tensor_hook(grad):
print("-----grad tensor:", grad)
return grad
def hook_fn_backward(module, grad_input, grad_output):
print(module) # 为了区分模块
# 为了符合反向传播的顺序,我们先打印 grad_output
print('grad_output', grad_output)
# 再打印 grad_input
print('grad_input', grad_input)
def parameter_hook(module,incompatible_keys):
print("parameter module:", module)
print("incompatible_keys:",incompatible_keys)
modules = model.named_children() #
for name, module in modules:
module.register_full_backward_hook(hook_fn_backward)
module.register_load_state_dict_post_hook(parameter_hook)