今天想用hook获取每层输出参数的时候。
调用pytorch中的API,发现没有好的方法能够将每层的层名和参数一对一的打印处理
为此修改了一下hook的源码,再调用就成了。
import torch
import torch.utils.hooks as hooks
def get_output_param(module, datasets):
output_param = {}
def hook(module, input, output):
name = list(module._forward_hooks.keys())[0]
output_param[name] = output.data
for name, layer in module.named_modules():
if name != "":
handle = hooks.RemovableHandle(layer._forward_hooks)
layer._forward_hooks[f"{name}"] = hook
module(datasets)
return output_param
end