mmrazor蒸馏中的无插入式提取蒸馏位点的方式和跨库调用的使用

mmrazor蒸馏中的无插入式提取蒸馏位点的方式:torch模型的子模型module.register_forward_hook(将数据保存下来的一个可以调用的函数)位点的提取特征图

调用的记录数据的函数接口需要有三个参数:model、input、和output

import torch
from typing import Tuple, Any

def forward_hook(module: torch.nn.Module, inputs: Tuple,  outputs: Any):
    data_buffer.append(outputs)


data_buffer = list()
input = torch.randn(1, 3, 50, 50)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1)
)

model_record_name = '2'
for name, module in model.named_modules():
    if name == model_record_name:  # 查找模型中命名的模型部分和记录的蒸馏位点名称是否一致,如果一致,就...
        module.register_forward_hook(forward_hook)  # module.register_forward_hook是pytorch里的东西
        break

output = model(input)
print(data_buffer[0].shape)

可以使用这种方式进行跨库调用
在这里插入图片描述
跨库调用的demo

from mmengine.hub import get_model
from mmengine.config import Config

cfg = {'cfg_path': 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'}
cfg = Config._dict_to_config_dict(cfg)


if cfg.get('cfg_path', None) and not cfg.get('type', None):
    model = get_model(**cfg)
    
print(type(model)) 

for name, module in model.named_modules():
    print(name)

demo中的model.named_modules()可以用来查看可以使用的蒸馏位点名称,放入到配置文件的source参数中
在这里插入图片描述

在这里插入图片描述

上下文形式,记录蒸馏位点数据

import torch.nn as nn
import torch
from typing import Type

class Registry:
    def __init__(self) -> None:
        self.data_buffer = list()
        
    def __enter__(self, ):
        self._data_buffer = list()
        
    def record_data_hook(self, model: nn.Module, input: Type, output: Type):
        self.data_buffer.append(output)        
        
    def __exit__(self, *args, **kwargs):
        pass


input = torch.randn(16, 3, 512, 512)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1)
)

registry = Registry()

source = '2'
for name, module in model.named_modules():
    if name == source:
        module.register_forward_hook(registry.record_data_hook)
        break

with registry:              # 进入时清空;前向传播时记录数据到data_buffer
    _ = model(input)
    
    
print("拿到了forward时特定位点的特征图: {}".format(registry.data_buffer[0].shape))

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值