建议阅读官方文档说明
该函数对模块注册一个钩子函数。每次模块的forward()函数计算输出后,都会调用该钩子函数。钩子函数必须具有以下函数签名hook(module, input, output)
一个简单的conv
模块调用一个计算参数量的钩子函数如下
def hook(module,input,output):
class_name = str(module.__class__.__name__)
if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
class_name.find("Linear") != -1:
params = 0
for param_ in module.parameters():
params += param_.view(-1).size(0)
print('params',params)
import torch
conv = torch.nn.Conv2d(1,8,(2,3))
input = torch.rand(1,1,224,224) # batch,channel,width,height
hook_handle = conv.register_forward_hook(hook)
output = conv(input)
hook_handle.remove()
output = conv(input)
可以看到,调用output = conv(input)
后,其注册的钩子函数也被一并调用,即计算并打印conv
层的参数量。随后,我们通过hook_handle.remove()
将该钩子函数移除,再次调用output = conv(input)
这一forward()函数,并没有调用钩子函数。因此程序最终只输出了一次调用钩子函数的结果。
params 56