register_forward_hook的理解


小白理解,希望大佬指正!!!

register_forward_hook(hook)在pytorch中文文档

register_forward_hook(hook)在pytorch中文文档
在module上注册一个forward hook。 每次调用forward()计算输出的时候,这个hook就会被调用。它应该拥有以下签名:

hook(module, input, output) -> None

hook不应该修改 input和output的值。 这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hook从module移除。

根据代码理解

class LayerActivations:
    features = None
 
    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
 
    def hook_fn(self, module, input, output):
        self.features = output.cpu()
        #print('11111111111111111111111111')
        #print(module)
 
    def remove(self):
        self.hook.remove()
        #print('22222222222222222222222222')
 
print('vgg.features:',vgg.features)
 
conv_out = LayerActivations(vgg.features, 3)  
img = next(iter(train_loader))[0]
#print('conv_out:',conv_out)
#print('conv_out.hook:',conv_out.hook)
 
# imshow(img)
o = vgg(Variable(img.cuda()))
#print('vgg:',vgg)
conv_out.remove()  
#print('remove')
act = conv_out.features
#print('features')

里面加了一些乱七八糟的#是在自己尝试的时候打印断句,看看这一部分之前的代码到底生成了什么。
这一部分的输出是:

conv_out: <__main__.LayerActivations object at 0x000002B3A58FBC88>
conv_out.hook: <torch.utils.hooks.RemovableHandle object at 0x000002B3AA612D48>
11111111111111111111111111
ReLU(inplace)
vgg: VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
22222222222222222222222222
remove
features

也就是说,在 conv_out = LayerActivations(vgg.features,3) 的时候,它完成了hook的建立,然后我理解的是它建立的 hook(module, input, output) 的module就由 model[layer_num].register_forward_hook(self.hook_fn) 指定的 model[layer_num] 决定,告诉了这个hook受到 hook_fn(self, module, input, output) 的input和output的影响,但是因为还没有输入 input和output,所以hook_fn还没有启用,于是没有print出来1111111111111。

当o = vgg(Variable(img.cuda())) ,由图片经过这个插入hook的完整的vgg模型后,指定module部分有了input和output,所以hook_fn启用,print出来1111111111111。

求大佬教导的问题Q

Q1:从vgg模型先后输出的结构来看,vgg模型并没有因为加入hook而在结构上有显式的变化,所以hook的存在只是为了用一种快捷的方法输出指定层的结果么?就是本身是对模型结构没影响的?

Q2:如果hook在模型中没有显式的变化,对模型结构没有影响,那么之后有必要 conv_out.remove() 么?

Q3:

conv_out: <__main__.LayerActivations object at 0x0000020A815035C8>
conv_out.hook: <torch.utils.hooks.RemovableHandle object at 0x0000020A82ED2748>

这两个是什么类型呢?
conv_out是一个指定的类LayerActivations,conv_out.hook是一个指定的句柄?

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值