官方学习文档地址:https://mmcv.readthedocs.io/zh_CN/latest/understand_mmcv/registry.html
注意args是mmcv的核心就行
1.环境搭建
mmcv需要基cuda否则会报错:connot import* form mmcv.*
pip install mmcv-full
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
2.注册器Registry
2.1定义
注册器是由字符串到类的映射,或者或是字典到类映射。
注册器类似一个参数列表,基于一个注册器我们可以继承出很多转换器
2.2参数列表
Registry(name, build_func=None, parent=None, scope=None)
第一个参数那么是默认注册器的名字,和我们在实例化的时候的名字不是同一个
第二个参数构建函数,一般需要我们自己编写,使用默认的。
2.3实例
通过 @CONVERTERS.register_module() 装饰所实现的模块,字符串和类之间的映射就可以由 CONVERTERS 构建和维护
from mmcv.utils import Registry
# 创建转换器(converter)的注册器(registry)
CONVERTERS = Registry('converter')
from .builder import CONVERTERS
# 使用注册器管理模块
@CONVERTERS.register_module()
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b
3.执行器
有两种执行器,一种是以epoch为单位,一种是以是以iter为单位训练。以EpochBasedRunner为例做介绍
3.1定义
相当于epoch_fit的for循环,和torch一样只支持训练和验证。
3.1实例
(1) dataloader、model 和优化器等类初始化
# 模型类初始化
model=...
# 优化器类初始化,典型值 cfg.optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, cfg.optimizer)
# 工作流对应的 dataloader 初始化
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
...) for ds in dataset
]
``
(2) runner 类(训练器)初始化
```python
runner = build_runner(
# cfg.runner 典型配置为
# runner = dict(type='EpochBasedRunner', max_epochs=200)
cfg.runner,
default_args=dict(
model=model,
batch_processor=None,
optimizer=optimizer,
logger=logger))
(3)规定一些训练参数
# 注册定制必需的 hook
runner.register_training_hooks(
# lr相关配置,典型为
# lr_config = dict(policy='step', step=[100, 150])
cfg.lr_config,
# 优化相关配置,例如 grad_clip 等
optimizer_config,
# 权重保存相关配置,典型为
# checkpoint_config = dict(interval=1),每个单位都保存权重
cfg.checkpoint_config,
# 日志相关配置
cfg.log_config,
...)
# 注册用户自定义 hook
# 例如想使用 ema 功能,则可以设置 custom_hooks=[dict(type='EMAHook')]
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
for hook_cfg in cfg.custom_hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
(4)开启训练流
之前的工作都是定义训练内容,此时才是真正开启了训练
runner.run(data_loaders, cfg.workflow)
4.label文件输入输出
mmcv不支持xml格式的label文件,目前只支持 json、yaml、txt 和 pickle。
import mmcv
# 从文件中读取数据
data = mmcv.load('test1.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')
# 从文件对象中读取数据
print(data)
with open('test1.json', 'r' ,encoding='utf-8') as f:
data = mmcv.load(f, file_format='json')
print(data)
5.图片文件输入输出
语法和opencv很像:
import mmcv
img = mmcv.imread('./test.png')
mmcv.imshow(img)
img = mmcv.imread('./test.png', flag='grayscale')
mmcv.imshow(img)
mmcv.imwrite(img, 'out.jpg')
import cv2
img = cv2.imread("test.png")
imgGrey = cv2.imread("test.png", 0)
cv2.imshow("img", img)
cv2.imshow("imgGrey", imgGrey)
cv2.waitKey()
cv2.imwrite("Copy.jpg", imgGrey)