MMCV核心组件知识整理

3.1 MMCV整体概述

提供了上层框架需要的 hook 机制以及可以直接使用的 runner

MMCV 提供了非常多的高性能 cuda 算子及其 python 接口

3.2 FileHandler

可参考https://zhuanlan.zhihu.com/p/336097883

fileio中的核心组件,设计文件读写。

mmcv提供了底层逻辑的读写handler,目前支持的有.json/.yaml/.yml/.pickle/.pkl文件

# 具体用法
import mmcv

# load data from a file
data = mmcv.load('test.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')

mmcv.dump(data, 'out.pkl')

mmcv支持自定义拓展的文件格式(即需要的文件格式不在上述列表),链接中给了.npy的例子。

3.3 FileClient

其作用是对外提供统一的文件内容获取 API,主要用于训练过程中数据的后端读取,通过用户选择默认或者自定义不同的 FileClient 后端,可以轻松实现文件缓存、文件加速读取等等功能

https://zhuanlan.zhihu.com/p/339190576

FileClinet用法示例,其实际调用在 mmseg/datasets/pipelines/loading.py/LoadImageFromFile 类中

class LoadImageFromFile(object): # 加载图片到内存中
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'cv2'
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        # 默认是disk后端
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('img_prefix') is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']
        # 读取图片字节内容
        img_bytes = self.file_client.get(filename)
        # 对字节内容进行解码
        img = mmcv.imfrombytes(
            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32},'
        repr_str += f"color_type='{self.color_type}',"
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str

扩展开发示例提供了img文件和annotations 文件不在同一个地方的例子。

3.4 Config

Config 主要是提供各种格式的配置文件解析功能,包括 py、json、ymal 和 yml,是一个非常基础常用类

https://zhuanlan.zhihu.com/p/346203167

3.4.1 Config用法汇总

3.4.1.1 通过 dict 生成 config

mmseg/configs目录下的很多文件是用这个方法定义的

cfg = Config(dict(a=1, b=dict(b1=[0, 1])))

# 可以通过 .属性方式访问,比较方便
cfg.b.b1 # [0, 1]
3.4.1.2 通过 配置文件 生成 config

该功能最为常用,配置文件可以是 py、yaml、yml 和 json 格式。

cfg = Config.fromfile('tests/data/config/a.py')

cfg.filename
cfg.item4 # 'test'
cfg # 打印 config path,和字典内容...
3.4.1.3 自动替换预定义变量

假设h.py文件里面存储的内容是:

cfg_dict = dict(
        item1='{{fileBasename}}',
        item2='{{fileDirname}}',
        item3='abc_{{fileBasenameNoExtension }}')

则可以通过参数 use_predefined_variables 实现自动替换预定义变量功能

# cfg_file 文件名是 h.py
cfg = Config.fromfile(cfg_file, use_predefined_variables=True)
print(cfg.pretty_text)

# 输出
item1 = 'h.py'
item2 = 'config 文件路径'
item3 = 'abc_h'

该参数主要用途是自动替换 Config 类中已经预定义好的变量模板为真实值,在某些场合有用,目前支持 4 个变量:fileDirname、fileBasename、fileBasenameNoExtension 和 fileExtname,预定义变量参考自 VS Code

如果 use_predefined_variables=False( 默认为 True ),则不会进行任何替换。

3.4.1.4 导入自定义模块

Config.fromfile 函数除了有 filenameuse_predefined_variables 参数外,还有 import_custom_modules,默认是 True,即当 cfg中存在 custom_imports 键时候会对里面的内容进行自动导入,其输入格式要么是 str 要么是 list[str],表示待导入的模块路径,一个典型用法是:

在mmseg/datasets目录下新建greenscreen.py时,需要在__init__里面加入

from .greenscreen import GreenScreenDataset

但是上述做法在某些场景下会比较麻烦。例如该模块处于非常深的层级,那么就需要逐层修改 __init__.py,有了本参数,便可以采用如下做法:

# .py 文件里面存储如下内容
custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)

# 自动导入 mmdet.models.backbones.mobilenet
Config.fromfile(cfg_file, import_custom_modules=True)
3.4.1.5 合并多个配置文件

(1) 从 base 文件中合并 Config 支持基于单个 base 配置文件,然后合并其余配置,最终得到一个汇总配置,该功能在各大上层框架中使用非常频繁,可以极大的增加配置复用性。一个典型用法是:

# base.py 内容

item1 = [1, 2]
item2 = {'a': 0}
item3 = True
item4 = 'test'

# d.py 内容
_base_ = './base.py'
item1 = [2, 3]
item2 = {'a': 1}
item3 = False
item4 = 'test_base'

# 用法
cfg = Config.fromfile('d.py')

# 输出
item1 = [2, 3]
item2 = dict(a=1)
item3 = False
item4 = 'test_base'

(2) 从多个 base 文件中合并 Config 同时也支持多个 base 文件合并得到最终配置,用户只需要在非 base 配置文件中将类似 _base_ = './base.py'改成 _base_ = ['./base.py',...] 即可。如配置configs/unet/deeplabv3_unet_s5_d16_256x256_40k_greenscreen.py时

_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/greenscreen.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170)))
evaluation = dict(metric='mDice')

需要特别强调的是

  • base 文件的 key 是不允许改的,必须是 _base_ ,否则程序不知道哪个字段才是 base
  • 多个 base 以 list 方式并行构建模式下,不允许多个 base 文件中有相同字段,程序会报 Duplicate Key Error,因为此时不知道以哪个配置为主

(3) 合并字典到配置 通过 cfg.merge_from_dict 函数接口可以实现对字典内容进行合并,典型用法如下:

cfg_file = osp.join(data_path, 'config/a.py')
cfg = Config.fromfile(cfg_file)
input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False}
cfg.merge_from_dict(input_options)

# 原始 a.py 内容为:
item1 = [1, 2]
item2 = {'a': 0}
item3 = True
item4 = 'test'

# 进行合并后, cfg 内容
item1 = [1, 2]
item2 = dict(a=1, b=0.1)
item3 = False
item4 = 'test'

(4) allow_list_keys 模式合并 假设某个配置文件中内容为:

item = [dict(a=0), dict(b=0, c=0)]

可以通过如下方式修改 list 内容:

input_options = {'item.0.a': 1, 'item.1.b': 1}
cfg.merge_from_dict(input_options, allow_list_keys=True)

# 输出
item = [dict(a=1), dict(b=1, c=0)]

如果 input_options 内部索引越界或者 allow_list_keys=False (默认是 True),则会报错。

(5) 允许删掉特定内容 该功能也比较常用,思考如下场景:在 RetinaNet 算法中,其采用的 bbox 回归 loss 配置如下:

loss_bbox=dict(type='L1Loss', loss_weight=1.0,其他参数)

上述配置是在 base 文件中,但是在 FASF 算法中采用的是 IOULoss,现在要做的事情是在 FASF 配置中自动覆盖掉 base 配置中的 L1Loss,可以采用如下做法:

loss_bbox=dict(
    _delete_=True,
    type='IoULoss',
    eps=1e-6,
    loss_weight=1.0,
    reduction='none')

如果没有 _delete_=True 参数,则两个配置会自动合并,L1Loss 中的其他参数始终会保留,无法删除,这肯定是不正确的( IoULoss 中不需要 L1Loss 的初始化参数),现在通过引入 _delete_ 保留字则可以实现忽略 base 相关配置,直接采用新配置文件字段功能。

3.4.1.6 pretty_text 和 dump

pretty_text 函数可以将字典内容按照 PEP8 格式打印,输出结构清晰,非常好看,如下所示:

# 直接打印字典内容
print(cfg._cfg_dict)
# 输出
{'item1': [1, 2], 'item2': {'a': 1, 'b': 0.1}, 'item3': False, 'item4': 'test'}

# pretty 打印字典内容
print(cfg.pretty_text)
# 输出
item1 = [1, 2]
item2 = dict(a=1, b=0.1)
item3 = False
item4 = 'test'

上述功能是解决第三方库 yapf 实现。而 dump 功能就是将 cfg 内容保存,当想查看实验配置是否正确、查看实验记录以及复现以前实验结果时候非常有用。

3.4.2 Config源码解析

见 https://zhuanlan.zhihu.com/p/346203167

3.5 Registry

Registry 用于提供全局类注册器功能

3.5.1 Registry 功能和用法

https://zhuanlan.zhihu.com/p/355271993

Registry 类可以提供一种完全相似的对外装饰函数来管理构建不同的组件,例如 backbones、head 和 necks 等等,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。

例如在 Faster R-CNN 的 backbone 模块实例化时,可以采用如下配置:

backbone=dict(
    type='ResNet', # 待实例化的类名
    depth=50, # 后面的都是对于的类初始化参数
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=True),
    norm_eval=True,
    style='pytorch'),

(1) 最简实现

# 方便起见,此处并未使用类方式构建,而是直接采用全局变量

_module_dict = dict()

# 定义装饰器函数
def register_module(name):
    def _register(cls):
        _module_dict[name] = cls
        return cls

    return _register

# 装饰器用法
@register_module('one_class')
class OneTest(object):
    pass

@register_module('two_class')
class TwoTest(object):
    pass

进行简单测试:

if __name__ == '__main__':
    # 通过注册类名实现自动实例化功能
    one_test = _module_dict['one_class']()
    print(one_test)

# 输出
<__main__.OneTest object at 0x7f1d7c5acee0>

可以发现只要将所定义的简单装饰器函数作用到类名上,然后内部采用 _module_dict 保存信息即可

(2) 实现无需传入参数,自动根据类名初始化类

_module_dict = dict()

def register_module(module_name=None):
    def _register(cls):
        name = module_name
        # 如果 module_name 没有给,则自动获取
        if module_name is None:
            name = cls.__name__
        _module_dict[name] = cls
        return cls

    return _register

@register_module('one_class')
class OneTest(object):
    pass

@register_module()
class TwoTest(object):
    pass

进行简单测试:

if __name__ == '__main__':
    one_test = _module_dict['one_class']
    # 方便起见,此处仅仅打印了类对象,而没有实例化。如果要实例化,只需要 one_test() 即可
    print(one_test)
    two_test = _module_dict['TwoTest']
    print(two_test)

# 输出
<class '__main__.OneTest '>
<class '__main__.TwoTest'>

(3) 实现重名注册强制报错功能

def register_module(module_name=None):
    def _register(cls):
        name = module_name
        if module_name is None:
            name = cls.__name__

        # 如果重名注册,则强制报错
        if name in _module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {name}')
        _module_dict[name] = cls
        return cls

    return _register

新增一个 force 参数即可

def register_module(module_name=None,force=False):
    def _register(cls):
        name = module_name
        if module_name is None:
            name = cls.__name__

        # 如果重名注册,则强制报错
        if not force and name in _module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {name}')
        _module_dict[name] = cls
        return cls

    return _register

测试:

@register_module('one_class')
class OneTest(object):
    pass

@register_module('one_class',True)
class TwoTest(object):
    pass

if __name__ == '__main__':
    one_test = _module_dict['one_class']
    print(one_test)

# 输出
<class '__main__.TwoTest'>

(5) 实现直接注册类功能

实现直接注册类的功能,只需要 _module_dict['name'] = module_class 即可。

3.5.2 Registry 类实现

class Registry:
    def __init__(self, name):
        # 可实现注册类细分功能
        self._name = name 
        # 内部核心内容,维护所有的已经注册好的 class
        self._module_dict = dict()

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # 最核心代码
        self._module_dict[module_name] = module_class

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # 如果已经是 module,那就知道 增加到字典中即可
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # 最标准用法
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

在 MMCV 中所有的类实例化都是通过 build_from_cfg 函数实现,做的事情非常简单,就是给定 module_name,然后从 self._module_dict 提取即可。

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type') # 注册 str 类名
    if is_str(obj_type):
        # 相当于 self._module_dict[obj_type]
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')

    # 如果已经实例化了,那就直接返回
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    # 最终初始化对于类,并且返回,就完成了一个类的实例化过程
    return obj_cls(**args)

一个完整的使用例子如下:

CONVERTERS = Registry('converter')

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)

3.6 Hook

处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

在我们熟知的 pytorch 中某个 tensor 或者 module 都有 register_hook(hook_fn) 函数,通过注册 hook,可以拦截和修改某些中间变量的值。

3.6.1 Hook如何用

在 python 中要实现 hook 机制,非常简单,传入一个函数即可,如下是一个简单的 hook,该 hook 的功能是打印内部变量

def hook(d):
   print(d)

def add(a,b,c,hook_fn=None)
   sum1=a+b
   if hook_fn is not None:
       hook_fn(sum1)
    return sum1+c

# 调用
add(1,2,3,hook)

在 PyTorch 中提供了非常方便的注册机制,用户可以随意插入任何函数来捕获中间过程,下面是一个简单的示例

import torch
from torch import nn
from mmcv.cnn import constant_init

# hook 函数,其三个参数不能修改(参数名随意),本质上是 PyTorch 内部回调函数

# module 本身对象
# input 该 module forward 前输入
# output 该 module forward 后输出
def forward_hook_fn(module, input, output):
    print('weight', module.weight.data)
    print('bias', module.bias.data)
    print('input', input)
    print('output', output)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(3, 1)
        self.fc.register_forward_hook(forward_hook_fn)
        constant_init(self.fc, 1)

    def forward(self, x):
        o = self.fc(x)
        return o

运行输出:

if __name__ == '__main__':
    model = Model()
    x = torch.Tensor([[0.0, 1.0, 2.0]])
    y = model(x)

# 输出
weight:tensor([[1., 1., 1.]])
bias: tensor([0.])
input: (tensor([[0., 1., 2.]]),)
output:tensor([[3.]], grad_fn=<AddmmBackward>)
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值