class UpBlock_attention(nn.Module):
def forward(self, fuse, out, fuse_a, out_a):
假如模块有四个输出,我需要勾取UpBlock_attention模块的输入,其中在模型中定义了self.trans3=UpBlock_attention,如果直接注册:
input_list = []
output_list = []
def forward_hook(model, input_data, output_data):
input_list.append(input_data)
output_list.append(output_data)
model.trans3.register_forward_hook(forward_hook)
在取其中的张量时候会报错:ValueError: only one element tensors can be converted to Python scalars。
for i in range(len(input_list)):
input_list_tensor = torch.tensor(input_list[i])
tensor_threewei = input_list_tensor.squeeze(0)
正确做法:在注册钩子函数时就直接定义只有一个输入:
input_list.append(input_data[1])
这样在遍历时不会报错。