MMrazor官网中知识蒸馏新手入门知识点

本文详细介绍了OpenMMlab的MMRazor库,涉及配置文件结构、训练和测试不同算法的方法,以及关键概念如算法、BaseAlgorithm、recorder和Delivery的使用。涵盖了神经架构搜索、剪枝和知识蒸馏等模型压缩技术的应用。
摘要由CSDN通过智能技术生成

 1.config文件构成

mmxxx:指的是OpenMMlab的一些任务仓库,如mmdet。

_base_:包括数据集,实验设置和模型架构的配置。

vanilla:mmrazor拥有的任务模型。

2.训练不同类型的算法

python tools/train.py ${CONFIG_FILE} [optional arguments]

例如:

python./tools/train.py
config/distill/mmcls/kd/kd_logits_r34_r18_8xb32_in1k.py
--work-dir your_work_dir

3.测试模型 

python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [optional arguments]

 例如:

python ./tools/test.py 
configs/distill/mmseg/cwd/cwd_logits_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py 
your_splitted_checkpoint_path --show

4.几个关键概念 

4.1Algorithm

(1)什么是算法

       mmrazor中包含四种主流技术,分别是神经架构搜索(NAS),剪枝(Purning)和知识蒸馏(KD)。算法是这些技术的通用项目,它在MMRazor中的作用和classifier在MMClassification和detector在MMDtection中的作用相同。

(2)关于基本算法

       所有模型压缩算法分为四个子目录如下图所示。这些算法必须继承自BaseAlgorithm。

(3)如何在MMRazor中使用现有算法

1)配置将要精简的架构

如果要直接使用OpenMMLab中的模型配置,下图中是相关示例。

_base_ = [
    'mmdet::_base_/models/faster_rcnn_r50_fpn.py',
]

architecture = _base_.model

2)将注册的算法应用于自己的体系结构

      这里需要注意在config中的arg名称是model而非argorithm。

model = dict(
    type='BaseAlgorithm',
    architecture=architecture)

3)将一些自定义钩子或循环应用于自己的算法中

*自定义挂钩

custom_hooks = [
    dict(type='NaiveVisualizationHook', priority='LOWEST'),
]

*自定义循环

_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']

# To chose from ['train_cfg', 'val_cfg', 'test_cfg'] based on your loop type
train_cfg = dict(
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator)

val_cfg = dict()
test_cfg = dict()

4.2 recorder

(1)什么是recorder

recorder是一个上下文管理器,用于记录模型转发过程中的各种中间结果。它可以通过在某些蒸馏算法中记录源数据来帮助完成数据交付。并且它还可用于获取一些特定数据,用于可视化分析或您想要的其他功能。

(2)recorder种类

(3)函数输出记录器

 实例化时,需要传递参数。使用函数路径进行实例化。

# instantiate with specifying used path
r1 = FunctionOutputsRecorder('toy_module.toy_func')

(4)方法输出记录器

 实例化时,需要传递参数。使用函数路径进行实例化。

# instantiate with specifying used path
r1 = MethodOutputsRecorder('toy_module.Toy.toy_func')

(5)模块输出记录器和模块输入记录器

不同于上面两个,使用模块名称进行实例化。

# instantiate with specifying module name.
r1 = ModuleOutputsRecorder('conv1')

(6)参数记录器

 使用参数名称而不是模块名称进行实例化。

# instantiate with specifying parameter name.
r1 = ParameterRecorder('toy_conv.weight')

(7)RecorderManager

 上下文管理器,可用于管理各种类型的记录器。

# configure multi-recorders
conv1_rec = ConfigDict(type='ModuleOutputs', source='conv1')
conv2_rec = ConfigDict(type='ModuleOutputs', source='conv2')
func_rec = ConfigDict(type='MethodOutputs', source='toy_module.Toy.toy_func')
# instantiate RecorderManager with a dict that contains recorders' configs,
# you can customize their keys.
manager = RecorderManager(
    {'conv1_rec': conv1_rec,
     'conv2_rec': conv2_rec,
     'func_rec': func_rec})

4.3 Delivery 

(1)什么是delivery

Delivery知识蒸馏中使用的一种机制,它通过在教师模型和学生模型之间传递和重写这些中间结果来调整它们之间的中间结果。

(2)delivery方式

FunctionOutputsDelivery和MethodOutputsDelivery

(3)函数输出交付

FunctionOutputsDelivery用于在教师模型和学生模型之间对齐函数的中间结果。

1)教师向学生传递单个函数的输出。

delivery = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_module.toy_func')

# override_data is False, which means that not override the data with
# the recorded data. So it will get the original output of toy_func
# in teacher model, and it is also recorded to be deliveried to the student.
delivery.override_data = False
with delivery:
    output_teacher = toy_module.toy_func()

# override_data is True, which means that override the data with
# the recorded data, so it will get the output of toy_func
# in teacher model rather than the student's.
delivery.override_data = True
with delivery:
    output_student = toy_module.toy_func()

print(output_teacher == output_student)

2)教师多输出向学生传递

delivery = FunctionOutputsDelivery(
    max_keep_data=2, func_path='toy_module.toy_func')

delivery.override_data = False
with delivery:
    output1_teacher = toy_module.toy_func()
    output2_teacher = toy_module.toy_func()

delivery.override_data = True
with delivery:
    output1_student = toy_module.toy_func()
    output2_student = toy_module.toy_func()

print(output1_teacher == output1_student and output2_teacher == output2_student)

(4)方法输出交付

1)无交付

# main.py
from mmcls.models.utils import Augments
from mmrazor.core import MethodOutputsDelivery

augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.0)
augments = Augments(augments_cfg)

imgs = torch.randn(2, 3, 32, 32)
label = torch.randint(0, 10, (2,))

imgs_teacher, label_teacher = augments(imgs, label)
imgs_student, label_student = augments(imgs, label)

print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))

2)有交付

delivery = MethodOutputsDelivery(
    max_keep_data=1, method_path='mmcls.models.utils.Augments.__call__')

delivery.override_data = False
with delivery:
    imgs_teacher, label_teacher = augments(imgs, label)

delivery.override_data = True
with delivery:
    imgs_student, label_student = augments(imgs, label)

print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))

(5)蒸馏交付管理器

 DistillDeliveryManager实际上是一个上下文管理器,用于管理交付。

from mmcls.models.utils import Augments
from mmrazor.core import DistillDeliveryManager

augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.0)
augments = Augments(augments_cfg)

distill_deliveries = [
    ConfigDict(type='MethodOutputs', max_keep_data=1,
               method_path='mmcls.models.utils.Augments.__call__')]

# instantiate DistillDeliveryManager
manager = DistillDeliveryManager(distill_deliveries)

imgs = torch.randn(2, 3, 32, 32)
label = torch.randint(0, 10, (2,))

manager.override_data = False
with manager:
    imgs_teacher, label_teacher = augments(imgs, label)

manager.override_data = True
with manager:
    imgs_student, label_student = augments(imgs, label)

print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值