Using hooks to grab a model's intermediate-layer outputs to help training

For a recent network-compression experiment I needed the model's intermediate-layer outputs to help with training. This was the first time I learned you could do this — actually make use of a model's intermediate-layer activations — and it feels like I've picked up a new skill.
The model I used is as follows:

import torch
import torch.nn as nn


def dwpw_conv(ic, oc, kernel_size=3, stride=2, padding=1):
    return nn.Sequential(
        nn.Conv2d(ic, ic, kernel_size, stride=stride, padding=padding, groups=ic),  # depthwise convolution
        nn.BatchNorm2d(ic),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Conv2d(ic, oc, 1),  # pointwise convolution
        nn.BatchNorm2d(oc),
        nn.LeakyReLU(0.01, inplace=True)
    )


class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()

        # 224 --> 112
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = dwpw_conv(64, 64, stride=1)
        self.layer2 = dwpw_conv(64, 128)
        self.layer3 = dwpw_conv(128, 256)
        self.layer4 = dwpw_conv(256, 140)
        # Here we adopt Global Average Pooling for various input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(140, 11)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.flatten(1)
        out = self.fc(out)
        return out
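
As a quick sanity check (my own usage snippet, not from the original assignment code), the adaptive average pooling indeed makes the classifier input-size agnostic:

net = StudentNet()
for size in (224, 28):
    x = torch.randn(2, 3, size, size)
    print(net(x).shape)  # torch.Size([2, 11]) for both input sizes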

Suppose we want the output of layer1. First, list the modules contained in the model together with their names:

net = StudentNet()

for (name, module) in net.named_modules():
    print(name)
    print(module)
 
# part of the output is shown below:
conv1
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
bn1
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
relu
ReLU(inplace=True)
maxpool
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
layer1
Sequential(
  (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.01, inplace=True)
  (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.01, inplace=True)
)
layer1.0
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
layer1.1
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer1.2
LeakyReLU(negative_slope=0.01, inplace=True)
layer1.3
Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
layer1.4
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
layer1.5
LeakyReLU(negative_slope=0.01, inplace=True)

Find the layer you want to hook, define a hook function, and register it:

features = []

def hook(module, input, output):
    # forward hook: runs after layer1's forward pass and records its output
    features.append(output)

net.layer1.register_forward_hook(hook)
input = torch.randn((64, 3, 28, 28))
output = net(input)
print(features[0].size())

## output:
torch.Size([64, 64, 7, 7]) 
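
Note that register_forward_hook returns a handle; if you only need the features temporarily, calling remove() on that handle detaches the hook so later forward passes stop appending to features. A minimal sketch:

handle = net.layer1.register_forward_hook(hook)
_ = net(torch.randn(1, 3, 28, 28))   # this pass appends to `features`
handle.remove()                      # detach the hook
_ = net(torch.randn(1, 3, 28, 28))   # this pass does not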

With this you can grab intermediate-layer outputs and use them to improve your algorithm. That said, my hook above is crude: collecting outputs in a bare list is clumsy, and it gets tedious if you want outputs from several layers at once. I recently came across a much cleaner version someone else wrote, so I'll share it here.

class HookTool:
    def __init__(self):
        self.fea = None

    def hook_fun(self, module, fea_in, fea_out):
        # store the layer's output on the hook object itself
        self.fea = fea_out


def get_feas_by_hook(model, names=['layer1', 'layer2', 'layer3']):
    # register one HookTool per requested layer and return them all
    fea_hooks = []
    for name, module in model.named_modules():
        if name in names:
            cur_hook = HookTool()
            module.register_forward_hook(cur_hook.hook_fun)
            fea_hooks.append(cur_hook)
    return fea_hooks

fea_hooks = get_feas_by_hook(net)
input = torch.randn((64, 3, 28, 28))
output = net(input)
print(fea_hooks[0].fea.size())

## output:
torch.Size([64, 64, 7, 7])
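
Since the whole point is to use these features to help training (e.g., feature-based distillation in network compression), here is a minimal, hypothetical sketch of how the hooked features could enter the loss. The teacher model `teacher_net`, the batch tensors `images`/`labels`, and the assumption that teacher and student share layer names and feature shapes are all mine, not from the original post:

import torch.nn.functional as F

student_hooks = get_feas_by_hook(net, names=['layer1'])
teacher_hooks = get_feas_by_hook(teacher_net, names=['layer1'])  # hypothetical teacher

logits = net(images)
with torch.no_grad():
    _ = teacher_net(images)  # populate the teacher hooks without building a graph

# hard-label loss plus an MSE feature-matching term on the hooked activations
loss = F.cross_entropy(logits, labels)
for s, t in zip(student_hooks, teacher_hooks):
    loss = loss + F.mse_loss(s.fea, t.fea)
loss.backward()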

To round out this post, here are a few more operations on model parameters.

# inspect each layer's parameters by name
for n, p in net.named_parameters():
    print(n)
    print(p)


# register buffers and retrieve their contents later

# snapshot every parameter into a buffer (buffer names may not contain '.')
for n, p in net.named_parameters():
    n = n.replace('.', '_')
    net.register_buffer('{}'.format(n), p.detach().clone())

# retrieve the snapshots later by the same mangled name
for n, p in net.named_parameters():
    n = n.replace('.', '_')
    p_new = getattr(net, '{}'.format(n))

These operations come in handy in the regularization term of lifelong learning, where you penalize drift away from a snapshot of the old parameters. Noting them down here for future reference.
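
For instance (a hypothetical sketch of the idea, with `reg_lambda` as an assumed hyperparameter), a simple regularizer penalizes the squared distance between the current parameters and the buffered snapshot — essentially EWC with all importance weights set to 1:

def lifelong_reg(model, reg_lambda=0.01):
    # squared distance between current params and the buffered snapshots
    reg = 0.0
    for n, p in model.named_parameters():
        old_p = getattr(model, n.replace('.', '_'))  # buffer registered above
        reg = reg + ((p - old_p) ** 2).sum()
    return reg_lambda * reg

# during training on a new task:
# loss = task_loss + lifelong_reg(net)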
