Pytorch获得网络任意层的特征图

 前言

        该篇文章通过hook方法来获得任意层的特征图。已经将方法封装为类,即拿即用。

一、获得模型各层级的名称和结构

        为了获得任意层的特征图,我们需要知道这个特征图的“索引”。然而,这些索引具有一定的格式,并不是我们可以直观的看出来的,为了找到这些索引与我们模型各层的对应关系,我们需要做以下工作:

1.获得我们直观可读的网络结构。程序示例如下:

save_file = open("model_structure.txt",'w')
# print的第一个参数是你的模型(torch.nn.Module),不需要载入权重
print(your_model_without_weight, file=save_file)
save_file.close()

         模型的结构会被保存到该程序文件夹下的model_structure.txt中,示例如下(以ACVNet部分为例):

DataParallel(
  (module): ACVNet(
    (feature_extraction): feature_extraction(
      (firstconv): Sequential(
        (0): Sequential(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ReLU(inplace=True)

2.获得网络各层的“索引”

def get_all_layers(model:torch.nn.Module):
    layer_dict = {name:layer for name, layer in model.named_modules()}
    print(layer_dict.keys())

        运行以上函数你将获得网络各层的索引,示例如下(以ACVNet部分为例):

dict_keys(['', 'feature_extraction', 'feature_extraction.firstconv', 'feature_extraction.firstconv.0', 'feature_extraction.firstconv.0.0', 'feature_extraction.firstconv.0.1', 'feature_extraction.firstconv.1',

 3.对比直观结构与“索引”的关系,明确你想要的特征图是由哪一层输出的

        我们可以看到“索引”是由直观结构中括号内的module名称和序号构成的,例如‘feature_extraction.firstconv.0.0’对应的是只管结构中的(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)一行。

二、通过hook获得任意层的特征图

1.特征图获取类

class FeatureExtractor():
    def __init__(self, model:torch.nn.Module, layer_name:str) -> None:
        self.features_in_hook = None
        self.features_out_hook = None

        layer_dict = {name:layer for name, layer in model.named_modules()}
        assert layer_name != None
        assert layer_name in layer_dict.keys()
        self.layer_name = layer_name
        
        layer_wanted = layer_dict[self.layer_name]
        layer_wanted.register_forward_hook(hook=self.hook)
    
    def hook(self, module, fea_in, fea_out) -> None:
        self.features_in_hook = fea_in[0]       # get layer input.fea_in is a tuple object with shape of (data,)
        self.features_out_hook = fea_out[0]     # get layer output
        return None
    
    def get_feature_in(self):
        return self.features_in_hook
    
    def get_feature_out(self):
        return self.features_out_hook

2.使用示例

        我们以获取一、2中所提及到的卷积层为例,我们获取它的输出特征图,程序如下:

# load model
model = your_model()

# load weight
model.load_state_dict(your_weight)

# get a hook
layer_name = "feature_extraction.firstconv.0.1"
feat_ext = FeatureExtractor(model=model, layer_name=layer_name)

# get feature
model.eval()
output = model(your_input)    #获得特征图之前,需要进行一次前向传播
feature = feat_ext.features_out_hook    #获得特征图

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值