前言
该篇文章通过hook方法来获得任意层的特征图。已经将方法封装为类,即拿即用。
一、获得模型各层级的名称和结构
为了获得任意层的特征图,我们需要知道这个特征图的“索引”。然而,这些索引具有一定的格式,并不是我们可以直观的看出来的,为了找到这些索引与我们模型各层的对应关系,我们需要做以下工作:
1.获得我们直观可读的网络结构。程序示例如下:
save_file = open("model_structure.txt",'w')
# print的第一个参数是你的模型(torch.nn.Module),不需要载入权重
print(your_model_without_weight, file=save_file)
save_file.close()
模型的结构会被保存到该程序文件夹下的model_structure.txt中,示例如下(以ACVNet部分为例):
DataParallel(
(module): ACVNet(
(feature_extraction): feature_extraction(
(firstconv): Sequential(
(0): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ReLU(inplace=True)
2.获得网络各层的“索引”
def get_all_layers(model:torch.nn.Module):
layer_dict = {name:layer for name, layer in model.named_modules()}
print(layer_dict.keys())
运行以上函数你将获得网络各层的索引,示例如下(以ACVNet部分为例):
dict_keys(['', 'feature_extraction', 'feature_extraction.firstconv', 'feature_extraction.firstconv.0', 'feature_extraction.firstconv.0.0', 'feature_extraction.firstconv.0.1', 'feature_extraction.firstconv.1',
3.对比直观结构与“索引”的关系,明确你想要的特征图是由哪一层输出的
我们可以看到“索引”是由直观结构中括号内的module名称和序号构成的,例如‘feature_extraction.firstconv.0.0’对应的是只管结构中的(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)一行。
二、通过hook获得任意层的特征图
1.特征图获取类
class FeatureExtractor():
def __init__(self, model:torch.nn.Module, layer_name:str) -> None:
self.features_in_hook = None
self.features_out_hook = None
layer_dict = {name:layer for name, layer in model.named_modules()}
assert layer_name != None
assert layer_name in layer_dict.keys()
self.layer_name = layer_name
layer_wanted = layer_dict[self.layer_name]
layer_wanted.register_forward_hook(hook=self.hook)
def hook(self, module, fea_in, fea_out) -> None:
self.features_in_hook = fea_in[0] # get layer input.fea_in is a tuple object with shape of (data,)
self.features_out_hook = fea_out[0] # get layer output
return None
def get_feature_in(self):
return self.features_in_hook
def get_feature_out(self):
return self.features_out_hook
2.使用示例
我们以获取一、2中所提及到的卷积层为例,我们获取它的输出特征图,程序如下:
# load model
model = your_model()
# load weight
model.load_state_dict(your_weight)
# get a hook
layer_name = "feature_extraction.firstconv.0.1"
feat_ext = FeatureExtractor(model=model, layer_name=layer_name)
# get feature
model.eval()
output = model(your_input) #获得特征图之前,需要进行一次前向传播
feature = feat_ext.features_out_hook #获得特征图