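1. Structure of the config files
mmxxx: refers to OpenMMLab task repositories, such as mmdet.
_base_: contains the configs for datasets, experiment settings, and model architectures.
vanilla: the task models that mmrazor itself provides.
2. Training different types of algorithms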
python tools/train.py ${CONFIG_FILE} [optional arguments]
For example:
python ./tools/train.py \
  configs/distill/mmcls/kd/kd_logits_r34_r18_8xb32_in1k.py \
  --work-dir your_work_dir
3. Testing a model
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [optional arguments]
For example:
python ./tools/test.py \
  configs/distill/mmseg/cwd/cwd_logits_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py \
  your_splitted_checkpoint_path --show
4. Key concepts
4.1 Algorithm
(1) What is an algorithm
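MMRazor covers four mainstream techniques: neural architecture search (NAS), pruning, knowledge distillation (KD), and quantization. An algorithm is the general term for an implementation of one of these techniques. Its role in MMRazor is the same as that of a classifier in MMClassification or a detector in MMDetection.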
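(2) About the base algorithm
All model compression algorithms are organized into four subdirectories, as shown in the figure below. Every algorithm must inherit from BaseAlgorithm.
(3) How to use existing algorithms in MMRazor
1) Configure the architecture to be slimmed
To use a model config from OpenMMLab directly, see the following example.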
_base_ = [
    'mmdet::_base_/models/faster_rcnn_r50_fpn.py',
]
architecture = _base_.model
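The 'mmdet::' prefix in the _base_ path names the OpenMMLab package that the base file is inherited from. The sketch below only illustrates how such a string separates into a scope and a relative path. split_scope is a hypothetical helper written for this note, not the config loader's actual implementation.

```python
def split_scope(base_path: str):
    """Split 'scope::relative/path.py' into (scope, relative_path).

    Hypothetical helper for illustration only. Paths without a '::'
    prefix are treated as local files and get a scope of None.
    """
    if "::" in base_path:
        scope, rel_path = base_path.split("::", 1)
        return scope, rel_path
    return None, base_path


# A cross-repository base file and a local base file.
print(split_scope("mmdet::_base_/models/faster_rcnn_r50_fpn.py"))
print(split_scope("./spos_shufflenet_supernet_8xb128_in1k.py"))
```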
2) Apply a registered algorithm to your own architecture
Note that the argument name in the config is model, not algorithm.
model = dict(
    type='BaseAlgorithm',
    architecture=architecture)
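To make the wrapping relationship concrete, here is a minimal sketch of an algorithm object that holds an architecture and forwards calls to it, built from a dict with a type key. ToyArchitecture, ToyAlgorithm, REGISTRY, and build are all hypothetical names for illustration; the real BaseAlgorithm and registry do considerably more.

```python
class ToyArchitecture:
    """Stand-in for the model to be slimmed (hypothetical)."""

    def __call__(self, x):
        return [2 * v for v in x]


class ToyAlgorithm:
    """Hypothetical wrapper: holds the architecture and forwards calls to it."""

    def __init__(self, architecture):
        self.architecture = architecture

    def __call__(self, x):
        # A plain base algorithm adds nothing; it just runs the architecture.
        return self.architecture(x)


# Toy registry mapping type names to classes (assumption for this sketch).
REGISTRY = {"ToyAlgorithm": ToyAlgorithm, "ToyArchitecture": ToyArchitecture}


def build(cfg):
    """Recursively build an object from a dict that has a 'type' key."""
    cfg = dict(cfg)
    cls = REGISTRY[cfg.pop("type")]
    kwargs = {
        key: build(value) if isinstance(value, dict) and "type" in value else value
        for key, value in cfg.items()
    }
    return cls(**kwargs)


architecture = dict(type="ToyArchitecture")
model = build(dict(type="ToyAlgorithm", architecture=architecture))
print(model([1, 2, 3]))
```

The top-level key is still called model in this sketch, mirroring the note above: the algorithm is what the runner sees as the model, and the architecture is nested inside it.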
3) Apply custom hooks or loops to your algorithm
* Custom hooks
custom_hooks = [
    dict(type='NaiveVisualizationHook', priority='LOWEST'),
]
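The priority field decides where a hook runs relative to the others. The sketch below orders hook configs by priority under the assumption that smaller numbers run earlier and that 'LOWEST' maps to the largest number. The PRIORITY table, order_hooks, ToyLoggerHook, and ToyFirstHook are hypothetical and only illustrate the ordering idea.

```python
# Assumed priority table for this sketch: smaller numbers run earlier.
PRIORITY = {"HIGHEST": 0, "NORMAL": 50, "LOWEST": 100}


def order_hooks(hooks):
    """Return hook configs in call order; a missing priority means NORMAL."""
    return sorted(hooks, key=lambda hook: PRIORITY[hook.get("priority", "NORMAL")])


custom_hooks = [
    dict(type="NaiveVisualizationHook", priority="LOWEST"),
    dict(type="ToyLoggerHook"),                       # hypothetical hook
    dict(type="ToyFirstHook", priority="HIGHEST"),    # hypothetical hook
]

call_order = [hook["type"] for hook in order_hooks(custom_hooks)]
print(call_order)
```

With priority='LOWEST', the visualization hook ends up last in this ordering, after every other hook has run.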
* Custom loops
_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
# Choose from ['train_cfg', 'val_cfg', 'test_cfg'] based on your loop type
train_cfg = dict(
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator)
val_cfg = dict()
test_cfg = dict()
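The _delete_=True entry matters because an inherited dict is normally merged key by key with the base config, which would leave stale keys from the base loop in place. The sketch below shows the difference with a toy merge function. merge_cfg and the base values (by_epoch, max_epochs, val_interval) are assumptions for illustration, not the actual config implementation or the real base file's contents.

```python
import copy


def merge_cfg(base, override):
    """Merge `override` into a copy of `base`, honoring `_delete_=True`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict):
            value = dict(value)
            delete = value.pop("_delete_", False)
            if not delete and isinstance(merged.get(key), dict):
                # Default behavior: keep base keys, update overlapping ones.
                merged[key] = merge_cfg(merged[key], value)
                continue
        # Non-dict values, new keys, or `_delete_=True`: replace wholesale.
        merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base config and loop settings.
base = dict(train_cfg=dict(by_epoch=True, max_epochs=120, val_interval=1))
loop = dict(type="mmrazor.EvolutionSearchLoop", max_epochs=20)

kept = merge_cfg(base, dict(train_cfg=loop))
replaced = merge_cfg(base, dict(train_cfg=dict(_delete_=True, **loop)))
print(kept["train_cfg"])      # still carries by_epoch and val_interval
print(replaced["train_cfg"])  # only the keys given for the new loop
```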
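4.2 Recorder
(1) What is a recorder
A recorder is a context manager that records various intermediate results during the model's forward pass. In some distillation algorithms it helps with data delivery by recording the source data. It can also be used to obtain specific data for visual analysis or any other purpose you need.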
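(2) Types of recorders
The types covered in the following items are the function outputs, method outputs, module outputs, module inputs, and parameter recorders.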
(3) Function outputs recorder
An argument must be passed at instantiation: the recorder is instantiated with the function's path.
# instantiate with specifying used path
r1 = FunctionOutputsRecorder('toy_module.toy_func')
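To show what "recording a function's outputs by path" can look like, here is a minimal sketch that swaps the function for a recording wrapper while the context is active and restores it on exit. ToyFunctionOutputsRecorder and the throwaway toy_module are hypothetical; this illustrates the idea and is not MMRazor's implementation.

```python
import importlib
import sys
import types

# Throwaway module so the example is self-contained (hypothetical).
toy_module = types.ModuleType("toy_module")
toy_module.toy_func = lambda x: x * 10
sys.modules["toy_module"] = toy_module


class ToyFunctionOutputsRecorder:
    """Records every return value of a function addressed by dotted path."""

    def __init__(self, source):
        self.module_path, self.func_name = source.rsplit(".", 1)
        self.data_buffer = []

    def __enter__(self):
        self.module = importlib.import_module(self.module_path)
        self.origin = getattr(self.module, self.func_name)

        def wrapper(*args, **kwargs):
            out = self.origin(*args, **kwargs)
            self.data_buffer.append(out)
            return out

        # Replace the module attribute with the recording wrapper.
        setattr(self.module, self.func_name, wrapper)
        return self

    def __exit__(self, *exc):
        # Restore the original function so nothing leaks out of the block.
        setattr(self.module, self.func_name, self.origin)
        return False


recorder = ToyFunctionOutputsRecorder("toy_module.toy_func")
with recorder:
    toy_module.toy_func(1)
    toy_module.toy_func(2)
print(recorder.data_buffer)
```

Because the wrapper is installed on the module attribute, this sketch only sees calls made as toy_module.toy_func(...). A name bound earlier with from toy_module import toy_func would bypass it.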
(4) Method outputs recorder
An argument must be passed at instantiation: the recorder is instantiated with the method's path.
# instantiate with specifying used path
r1 = MethodOutputsRecorder('toy_module.Toy.toy_func')
(5) Module outputs recorder and module inputs recorder
Unlike the two recorders above, these are instantiated with a module name.
# instantiate with specifying module name.
r1 = ModuleOutputsRecorder('conv1')
(6) Parameter recorder
Instantiated with a parameter name rather than a module name.
# instantiate with specifying parameter name.
r1 = ParameterRecorder('toy_conv.weight')
(7) RecorderManager
A context manager that can be used to manage recorders of various types.
# ConfigDict is assumed to come from mmengine here
from mmengine import ConfigDict
from mmrazor.core import RecorderManager
# configure multiple recorders
conv1_rec = ConfigDict(type='ModuleOutputs', source='conv1')
conv2_rec = ConfigDict(type='ModuleOutputs', source='conv2')
func_rec = ConfigDict(type='MethodOutputs', source='toy_module.Toy.toy_func')
# instantiate RecorderManager with a dict that contains the recorders'
# configs; you can customize their keys.
manager = RecorderManager(
    {'conv1_rec': conv1_rec,
     'conv2_rec': conv2_rec,
     'func_rec': func_rec})
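The idea of one context manager driving several recorders can be sketched with contextlib.ExitStack. ToyRecorder and ToyRecorderManager below are hypothetical stand-ins: they only track whether recording is active, and they leave out the hooking logic a real recorder needs.

```python
from contextlib import ExitStack


class ToyRecorder:
    """Minimal recorder that is active only inside its `with` block."""

    def __init__(self, source):
        self.source = source
        self.recording = False
        self.data_buffer = []

    def __enter__(self):
        self.recording = True
        return self

    def __exit__(self, *exc):
        self.recording = False
        return False

    def record(self, value):
        # Values offered outside the `with` block are ignored.
        if self.recording:
            self.data_buffer.append(value)


class ToyRecorderManager:
    """Enters and exits a dict of recorders together."""

    def __init__(self, recorder_cfgs):
        self.recorders = {
            name: ToyRecorder(cfg["source"]) for name, cfg in recorder_cfgs.items()
        }
        self._stack = None

    def get_recorder(self, name):
        return self.recorders[name]

    def __enter__(self):
        self._stack = ExitStack()
        for recorder in self.recorders.values():
            self._stack.enter_context(recorder)
        return self

    def __exit__(self, *exc):
        return self._stack.__exit__(*exc)


manager = ToyRecorderManager(
    {"conv1_rec": dict(source="conv1"), "conv2_rec": dict(source="conv2")})
with manager:
    manager.get_recorder("conv1_rec").record("feat1")
manager.get_recorder("conv1_rec").record("ignored")
print(manager.get_recorder("conv1_rec").data_buffer)
```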
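4.3 Delivery
(1) What is a delivery
Delivery is a mechanism used in knowledge distillation. It aligns the intermediate results of the teacher model and the student model by delivering those results from one to the other and overwriting them there.
(2) Types of delivery
FunctionOutputsDelivery and MethodOutputsDelivery
(3) Function outputs delivery
FunctionOutputsDelivery aligns a function's intermediate results between the teacher model and the student model.
1) Delivering a single function output from the teacher to the student.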
delivery = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_module.toy_func')
# override_data is False, so the output is not overridden with recorded
# data: toy_func returns its original output in the teacher model, and that
# output is recorded for delivery to the student.
delivery.override_data = False
with delivery:
    output_teacher = toy_module.toy_func()
# override_data is True, so the output is overridden with the recorded
# data: toy_func returns the teacher's output rather than the student's.
delivery.override_data = True
with delivery:
    output_student = toy_module.toy_func()
print(output_teacher == output_student)
2) Delivering multiple outputs from the teacher to the student
delivery = FunctionOutputsDelivery(
    max_keep_data=2, func_path='toy_module.toy_func')
delivery.override_data = False
with delivery:
    output1_teacher = toy_module.toy_func()
    output2_teacher = toy_module.toy_func()
delivery.override_data = True
with delivery:
    output1_student = toy_module.toy_func()
    output2_student = toy_module.toy_func()
print(output1_teacher == output1_student and output2_teacher == output2_student)
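The record-then-replay behavior in the two cases above can be sketched with a queue: outputs are appended while override_data is False and popped in the same order while it is True. ToyFunctionOutputsDelivery and toy_module below are hypothetical. In particular, using a bounded deque for max_keep_data is a simplification chosen for this sketch, not a statement about how MMRazor enforces the limit.

```python
import sys
import types
from collections import deque
from itertools import count

# Throwaway module whose function returns a different value on each call.
toy_module = types.ModuleType("toy_module")
_counter = count()
toy_module.toy_func = lambda: next(_counter)
sys.modules["toy_module"] = toy_module


class ToyFunctionOutputsDelivery:
    """Records outputs in the teacher pass and replays them for the student."""

    def __init__(self, max_keep_data, func_path):
        self.module_name, self.func_name = func_path.rsplit(".", 1)
        self.data_queue = deque(maxlen=max_keep_data)
        self.override_data = False

    def __enter__(self):
        self.module = sys.modules[self.module_name]
        self.origin = getattr(self.module, self.func_name)

        def wrapper(*args, **kwargs):
            if self.override_data:
                # Student pass: replay the oldest recorded teacher output.
                return self.data_queue.popleft()
            out = self.origin(*args, **kwargs)
            self.data_queue.append(out)
            return out

        setattr(self.module, self.func_name, wrapper)
        return self

    def __exit__(self, *exc):
        setattr(self.module, self.func_name, self.origin)
        return False


delivery = ToyFunctionOutputsDelivery(
    max_keep_data=2, func_path="toy_module.toy_func")
delivery.override_data = False
with delivery:
    teacher = [toy_module.toy_func(), toy_module.toy_func()]
delivery.override_data = True
with delivery:
    student = [toy_module.toy_func(), toy_module.toy_func()]
print(teacher == student)
```

The first-in, first-out order is what lets the first student call line up with the first teacher call when more than one output is kept.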
(4) Method outputs delivery
1) Without delivery
# main.py
import torch
from mmcls.models.utils import Augments
from mmrazor.core import MethodOutputsDelivery
augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.0)
augments = Augments(augments_cfg)
imgs = torch.randn(2, 3, 32, 32)
label = torch.randint(0, 10, (2,))
# Each call applies its own random mixup, so the teacher and student
# results generally differ.
imgs_teacher, label_teacher = augments(imgs, label)
imgs_student, label_student = augments(imgs, label)
print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))
2) With delivery
delivery = MethodOutputsDelivery(
    max_keep_data=1, method_path='mmcls.models.utils.Augments.__call__')
# Teacher pass: run the augmentation and record its output.
delivery.override_data = False
with delivery:
    imgs_teacher, label_teacher = augments(imgs, label)
# Student pass: receive the recorded teacher output instead.
delivery.override_data = True
with delivery:
    imgs_student, label_student = augments(imgs, label)
print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))
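The same record-then-replay idea applies to a method of a class, which is what makes a random augmentation produce identical results for teacher and student. The sketch below patches __call__ on a toy class. ToyAugment and ToyMethodOutputsDelivery are hypothetical, and passing the class and method name separately (rather than a dotted method_path string) is a simplification for this sketch.

```python
import random
from collections import deque


class ToyAugment:
    """Stand-in for a random augmentation such as mixup (hypothetical)."""

    def __call__(self, values):
        lam = random.random()
        return [lam * v for v in values]


class ToyMethodOutputsDelivery:
    """Records a method's outputs, then replays them when overriding."""

    def __init__(self, max_keep_data, cls, method_name):
        self.cls, self.method_name = cls, method_name
        self.data_queue = deque(maxlen=max_keep_data)
        self.override_data = False

    def __enter__(self):
        self.origin = getattr(self.cls, self.method_name)
        delivery = self

        def wrapper(obj, *args, **kwargs):
            if delivery.override_data:
                # Student pass: hand back the recorded teacher output.
                return delivery.data_queue.popleft()
            out = delivery.origin(obj, *args, **kwargs)
            delivery.data_queue.append(out)
            return out

        # Patch the method on the class so every instance is affected.
        setattr(self.cls, self.method_name, wrapper)
        return self

    def __exit__(self, *exc):
        setattr(self.cls, self.method_name, self.origin)
        return False


augment = ToyAugment()
values = [1.0, 2.0, 3.0]
delivery = ToyMethodOutputsDelivery(1, ToyAugment, "__call__")
delivery.override_data = False
with delivery:
    teacher = augment(values)
delivery.override_data = True
with delivery:
    student = augment(values)
print(teacher == student)
```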
(5) Distill delivery manager
DistillDeliveryManager is itself a context manager, used to manage deliveries.
import torch
# ConfigDict is assumed to come from mmengine here
from mmengine import ConfigDict
from mmcls.models.utils import Augments
from mmrazor.core import DistillDeliveryManager
augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.0)
augments = Augments(augments_cfg)
distill_deliveries = [
    ConfigDict(type='MethodOutputs', max_keep_data=1,
               method_path='mmcls.models.utils.Augments.__call__')]
# instantiate DistillDeliveryManager
manager = DistillDeliveryManager(distill_deliveries)
imgs = torch.randn(2, 3, 32, 32)
label = torch.randint(0, 10, (2,))
# Teacher pass: record the augmentation output.
manager.override_data = False
with manager:
    imgs_teacher, label_teacher = augments(imgs, label)
# Student pass: override with the recorded output.
manager.override_data = True
with manager:
    imgs_student, label_student = augments(imgs, label)
print(torch.equal(label_teacher, label_student))
print(torch.equal(imgs_teacher, imgs_student))