Reading the Source Code of the 3D Point Cloud Deep Network PointNeXt (I): Registry Mechanism and Argument Parsing




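Related posts
[1] Installation, Configuration and Testing of the 3D Point Cloud Deep Network PointNeXt
[2] Reading the Source Code of the 3D Point Cloud Deep Network PointNeXt (I): Registry Mechanism and Argument Parsing ⇐ this post
[3] Reading the Source Code of the 3D Point Cloud Deep Network PointNeXt (II): Point Cloud Dataset Construction and Preprocessing
[4] Reading the Source Code of the 3D Point Cloud Deep Network PointNeXt (III): Backbone Network Model
[5] Reading the Source Code of the 3D Point Cloud Deep Network PointNeXt (IV): PointNeXt-B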


Preface

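I have worked through part of the PointNeXt source code and am writing it down here for future reference.

This post has two parts: the registry mechanism and argument parsing. The registry mechanism is the part that most needs to be understood.

All annotations and debugging output below come from the following test session.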

CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py \
				--cfg cfgs/s3dis/pointnext-s.yaml  mode=train

I. Registry Mechanism

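The registry mechanism in PointNeXt registers modules/classes so that a string can be mapped to a module/class. In other words, a parameter string read from a configuration file can be mapped directly to the corresponding module/class, and an instance of it can be created. The PointNeXt authors based this part on the registry mechanism in mmcv.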

1. The Registry Class

The registry mechanism itself is implemented by the class Registry. Its key methods are:

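Method | Description
--- | ---
__init__() | Initializes the instance, including the dictionary of registered modules self._module_dict = dict()
get(self, key) | Maps a string to a class: the string key is mapped to the class registered in self._module_dict, i.e. self._module_dict[real_key]
register_module(self, name=None, force=False, module=None) | Registers a module/class; also serves as the decorator applied to a module/class
_register(cls) | The inner wrapper function of the decorator register_module, used when register_module is applied as a decorator; cls is the class being decorated. It calls _register_module(self, module_class, module_name=None, force=False) to register the class (self._module_dict[name] = module_class) and then returns cls unchanged, with no other processing
build_from_cfg(cfg, registry, default_args=None) | A module-level function in registry.py (not a method of Registry) that serves as the default build_func; builds a module/class instance from a config dict, i.e. turns a string into a module/class instance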
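The following stripped-down sketch makes the table concrete. It is not the openpoints code: `MiniRegistry` and `ToySeg` are hypothetical names, and it omits scopes, parent/child registries and most error checks. It keeps only the three pieces the table describes: the `_module_dict`, the `register_module()` decorator whose inner `_register` returns the class unchanged, and a build step that pops `NAME` and passes the remaining items as keyword arguments.

```python
import copy


class MiniRegistry:
    """Hypothetical, stripped-down registry: name string -> class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None, force=False):
        # Decorator factory: returns the wrapper that records the class
        def _register(cls):
            key = name or cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # the class itself is returned unchanged
        return _register

    def get(self, key):
        return self._module_dict.get(key)

    def build(self, cfg):
        # NAME selects the class; every other item becomes a keyword argument
        cfg = copy.deepcopy(cfg)  # do not mutate the caller's dict
        obj_type = cfg.pop('NAME')
        obj_cls = self.get(obj_type)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return obj_cls(**cfg)


MODELS = MiniRegistry('models')


@MODELS.register_module()
class ToySeg:  # hypothetical stand-in for BaseSeg
    def __init__(self, num_classes=2):
        self.num_classes = num_classes


cfg = {'NAME': 'ToySeg', 'num_classes': 13}
model = MODELS.build(cfg)
print(type(model).__name__, model.num_classes)  # ToySeg 13
```

Because `_register` returns `cls` itself, decorating a class does not change how it is instantiated later. The only side effect is the dictionary entry.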

openpoints/utils/registry.py defines the Registry class. My annotated copy follows.

# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy 

class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced useage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.
        The name of the package where registry is defined will be returned.
        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.
        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
 

        split_filename = filename.split('.')  # ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # declare a static method (takes neither self nor cls)
    def split_scope_key(key):
        """Split scope and key.
        The first scope will be split from key.
        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')  
        # find() returns -1 if key contains no '.', otherwise the index of the first '.'
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        # Maps a string to a class: the string key is looked up in
        # self._module_dict and the registered class self._module_dict[real_key] is returned
        """Get the registry record.
        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)  
        # key = BaseSeg; scope = None; real_key = BaseSeg
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.
        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.
        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None): 
        # also used as a decorator
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

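        # raise the error ahead of time
        # (mmcv calls misc.is_seq_of(name, str) here, but `misc` is not imported in this
        # file, so the sequence-of-str check is written out inline)
        if not (name is None or isinstance(name, str)
                or (isinstance(name, (list, tuple)) and all(isinstance(n, str) for n in name))):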
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # plain (non-decorator) call of register_module
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # this is the decorator's wrapper function;
        # in the decorator case, cls is the class being decorated
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator returns this wrapper function

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')   # 'BaseSeg'

    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  
        # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # look up the class/module registered under this name string in self._module_dict,
        # i.e. the string-to-class mapping
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args) 
        obj_cfg.pop('NAME')  
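        # remove the "NAME" item; obj_cfg keeps all the other items
        # ("NAME" has already done its job of selecting obj_cls)
        return obj_cls(**obj_cfg)
        # with the items unpacked, this is openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg)
        # The @MODELS.register_module() decorator on BaseSeg ran once, when base_seg.py was imported,
        # and returned the class unchanged; it is NOT equivalent to MODELS.register_module(module=BaseSeg(**obj_cfg)).
        # So this line is plain instantiation: nothing is registered again,
        # it simply runs openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg)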

    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
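The test session never exercises the scope/children branches of `get()`, because MODELS has no parent and its scope is inferred as 'openpoints'. A small hypothetical sketch can still illustrate them. `ScopedRegistry`, its `register` helper and the scope names are assumptions for illustration. The sketch differs deliberately in one respect: for an unknown scope it returns `None` at the root registry. The code above would instead reach `self.parent` being `None` and raise an `AttributeError`.

```python
class ScopedRegistry:
    """Hypothetical sketch of the scope/children lookup in Registry.get()."""

    def __init__(self, name, scope, parent=None):
        self.name = name
        self.scope = scope
        self._module_dict = {}
        self.children = {}
        self.parent = parent
        if parent is not None:
            # same effect as parent._add_children(self): index the child by its scope
            assert scope not in parent.children, f'scope {scope} exists'
            parent.children[scope] = self

    @staticmethod
    def split_scope_key(key):
        # 'openpoints.BaseSeg' -> ('openpoints', 'BaseSeg'); 'BaseSeg' -> (None, 'BaseSeg')
        index = key.find('.')
        if index != -1:
            return key[:index], key[index + 1:]
        return None, key

    def register(self, cls):
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self.scope:
            return self._module_dict.get(real_key)      # look in this registry
        if scope in self.children:
            return self.children[scope].get(real_key)   # delegate to the child
        root = self
        while root.parent is not None:                  # otherwise climb to the root
            root = root.parent
        if root is self:
            return None                                 # the original code raises here
        return root.get(key)


root = ScopedRegistry('models', scope='root')
child = ScopedRegistry('models', scope='openpoints', parent=root)


@child.register
class PointNextEncoder:   # empty stand-in with a familiar name
    pass


print(root.get('openpoints.PointNextEncoder'))   # found through root.children
print(child.get('PointNextEncoder'))             # no scope: child's own dict
print(root.get('PointNextEncoder'))              # None: root has no such key itself
```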

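2. Registering Classes

First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to as the registry.

Because of Python's import mechanism, MODELS is created when build.py is first imported. This happens at program start-up, before the __main__ block and main() run.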

from openpoints.utils import registry
MODELS = registry.Registry('models')
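# create a registry.Registry object (MODELS, also called the registry) as a module-level global;
# it is created when this module is first imported, before __main__/main() runs, so MODELS exists from the start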

def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT): 
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)

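The import mechanism also handles registration. Once MODELS exists, any class under openpoints/models whose definition is decorated with @MODELS.register_module() is registered into MODELS as soon as its module is imported and the class statement executes.

For example, the BaseSeg class below is registered into MODELS._module_dict this way, before main() runs.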

"""
Author: PointNeXt
"""
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d

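# The decorator MODELS.register_module() is applied to class BaseSeg.
# It runs once, when this module is imported and the class statement executes; this is equivalent to
# MODELS.register_module(module=BaseSeg) (the class, not an instance). Later calls to BaseSeg(...) do not register anything again.
# This happens at program start-up, before __main__/main() but after the registry MODELS has been created,
# so by the time main() runs, the registry can already map the string 'BaseSeg' to the class.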
@MODELS.register_module()      
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()
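You can check in isolation whether the decorator registers a class again on every instantiation. In the following hypothetical sketch, `register_module`, `ToySeg` and the `calls` list are illustration-only names, not openpoints code. The wrapper runs exactly once, when the `class` statement is executed; in PointNeXt that is the moment base_seg.py is imported. After that, `ToySeg(...)` is ordinary instantiation.

```python
calls = []      # records every time the registration wrapper runs
registry = {}   # hypothetical stand-in for MODELS._module_dict


def register_module():
    def _register(cls):
        calls.append(cls.__name__)
        registry[cls.__name__] = cls
        return cls  # same class object, no wrapping
    return _register


@register_module()          # runs here, while the class statement executes
class ToySeg:               # hypothetical stand-in for BaseSeg
    def __init__(self, **kwargs):
        self.kwargs = kwargs


print(calls)                # ['ToySeg'] -> registered at definition time
a = ToySeg(num_classes=13)  # plain instantiation
b = registry['ToySeg'](num_classes=13)
print(calls)                # ['ToySeg'] -> no re-registration
print(registry['ToySeg'] is ToySeg)  # True
```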

While debugging, inspecting MODELS._module_dict shows that many classes have already been registered.

MODELS._module_dict = {
'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>, 
'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>, 
'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>, 
'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>, 
'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>, 
'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>, 
'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>, 
'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>, 
'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>, 
'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>, 
'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>, 
'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>, 
'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>, 
'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>, 
'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>, 
'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>, 
'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>, 
'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>, 
'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>, 
'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>, 
'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>, 
'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>, 
'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>, 
'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>, 
'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>, 
'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>, 
'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>, 
'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>, 
'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>, 
'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>, 
'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>, 
'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>, 
'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>, 
'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>, 
'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>, 
'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}

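3. Using the Registry

With MODELS created and the classes registered in it, its string-to-class mapping can be used to create class instances conveniently from .yaml configuration files.

The rough call sequence is shown in the figure below: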

[Sequence diagram; participants: examples/segmentation/main(), openpoints/models/build.py, class Registry (openpoints/utils/registry.py)]
1. On import, build.py creates the instance registry.Registry('models'); Registry.__init__() sets self.build_func = build_from_cfg, giving the global object MODELS (the registry).
2. main() calls build_model_from_cfg(cfg.model).
3. build_model_from_cfg calls MODELS.build(cfg, **kwargs).
4. Registry.build(self, *args, **kwargs) calls build_from_cfg(cfg, registry, default_args=None).
5. build_from_cfg returns obj_cls(**obj_cfg) (equivalent to BaseSeg(**obj_cfg)), and the model is returned to main().
Fig 1. Call sequence for creating a class object (the deep neural network model) through the registry

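The config dict cfg read from the .yaml files contains a NAME entry, and registry.get(*) returns the registered class for the NAME string. Once the class has been found, the NAME entry has done its job. The remaining entries are used to configure and construct the concrete deep neural network modules/classes in PointNeXt automatically (not covered in this post).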
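The following sketch shows how a nested config could be turned into objects using the same NAME-plus-kwargs convention. It assumes that the top-level model builds its sub-modules from the nested `*_args` dicts, as the `cfg.model` layout in Section II.3 suggests. `ToySeg`, `ToyEncoder`, `REGISTRY` and this simplified `build_from_cfg` are hypothetical stand-ins, not the real BaseSeg or PointNextEncoder.

```python
import copy

REGISTRY = {}  # hypothetical stand-in for MODELS._module_dict


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


def build_from_cfg(cfg):
    """Simplified build: NAME selects the class, the other items become kwargs."""
    cfg = copy.deepcopy(cfg)          # leave the caller's config untouched
    obj_cls = REGISTRY[cfg.pop('NAME')]
    return obj_cls(**cfg)


@register
class ToyEncoder:
    def __init__(self, in_channels, width=32):
        self.in_channels = in_channels
        self.width = width


@register
class ToySeg:
    def __init__(self, encoder_args=None, **kwargs):
        # the sub-module is built from its own nested dict, again via NAME
        self.encoder = build_from_cfg(encoder_args)
        self.extra = kwargs


model_cfg = {
    'NAME': 'ToySeg',
    'encoder_args': {'NAME': 'ToyEncoder', 'in_channels': 4, 'width': 32},
    'in_channels': 4,
}
model = build_from_cfg(model_cfg)
print(type(model.encoder).__name__, model.encoder.width)   # ToyEncoder 32
print(model.extra)                                          # {'in_channels': 4}
```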

For the detailed annotations, see build_from_cfg in the registry.py listing above. Note that it is a module-level function in that file, not a method of Registry.

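II. Argument Parsing

The registry mechanism needs string parameters to build class instances. These parameters come from a parsing step that reads the configuration in the .yaml files into the program.

1. Command-Line Parsing

The main program first reads the relevant .yaml files and merges them into the cfg dict. Annotated below.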

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    # create the parser
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    # add the arguments
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml  mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by the parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']

    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = cfgs/s3dis/pointnext-s.yaml
    cfg.update(opts)  
    # overwrite the default arguments in yml
    # mode=train is written into the cfg dict

    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1  # only one GPU when debugging; several GPUs can run in parallel in normal runs

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]  
    # task/dataset name, \eg s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]  
    # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the folder of name under ./cfgs)
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = [] # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)  
        # requires pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # dump cfg into file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
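You can reproduce the command-line split without launching training. This sketch passes the test session's arguments to the same two `add_argument` calls as an explicit list. The `argv` list is my assumption of what the shell passes; the environment variable never reaches `sys.argv`. The sketch also repeats the `split('.')`/`split('/')` derivation of `task_name` and `cfg_basename`. The `os.path`-based alternative at the end is my own suggestion, not the repo's code; it tolerates extra dots in the file name.

```python
import argparse
import os

parser = argparse.ArgumentParser('Scene segmentation training/testing')
parser.add_argument('--cfg', type=str, required=True, help='config file')
parser.add_argument('--profile', action='store_true', default=False,
                    help='set to True to profile speed')

# What the test session passes after the script name (assumed argv)
argv = ['--cfg', 'cfgs/s3dis/pointnext-s.yaml', 'mode=train']
args, opts = parser.parse_known_args(argv)
print(args)   # Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
print(opts)   # ['mode=train']

# Derivation used in main.py
stem = args.cfg.split('.')[-2]          # 'cfgs/s3dis/pointnext-s'
task_name = stem.split('/')[-2]         # 's3dis'
cfg_basename = stem.split('/')[-1]      # 'pointnext-s'

# Alternative suggestion (not in the repo): does not break on extra dots
alt_task_name = os.path.basename(os.path.dirname(args.cfg))       # 's3dis'
alt_basename = os.path.splitext(os.path.basename(args.cfg))[0]    # 'pointnext-s'
print(task_name, cfg_basename, alt_task_name, alt_basename)
```

`parse_known_args` does not raise an error on arguments it does not recognise, such as `mode=train`. It collects them in `opts` instead, which is what lets arbitrary `key=value` overrides follow `--cfg`.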

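2. Loading and Updating the Configuration

The EasyConfig class implements reading and updating of the configuration entries. Partially annotated below.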

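# imports needed by this excerpt
import os
import json
import hashlib
from ast import literal_eval
from typing import Any, Dict, List, Tuple, Union

import yaml
from multimethod import multimethod


class EasyConfig(dict):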
    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml

        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        # 'cfgs/s3dis/pointnext-s.yaml'
        if recursive:  # True
            extension = os.path.splitext(fpath)[1]   # .yaml
            while os.path.dirname(fpath) != fpath:   # loop until dirname() no longer changes the path ('' or '/')
                fpath = os.path.dirname(fpath)  # drop the last path component, one level per iteration
                fpaths.append(os.path.join(fpath, 'default' + extension))   
                #  fpaths =['cfgs/s3dis/pointnext-s.yaml', 
                #           'cfgs/s3dis/default.yaml', 
                #           'cfgs/default.yaml', 
                #           'default.yaml']
        for fpath in reversed(fpaths):   # most general default.yaml first, so more specific files override it
            if os.path.exists(fpath):
                with open(fpath) as f:
                    self.update(yaml.safe_load(f))   
                    # merge the entries of every existing .yaml file in fpaths into this one dict

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # multimethod makes python support function overloading
    @multimethod
    def update(self, other: Dict) -> None:  
        # merge the .yaml items into this dict as key: value pairs
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):
                    # nested entry: create a sub-EasyConfig for it
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except:
                pass
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                seperator = '\n'
            else:
                seperator = ' '
            text = key + ':' + seperator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
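The two `update` overloads can be sketched with plain dicts. `merge` and `apply_opts` below are hypothetical helper names that mirror `update(Dict)` and `update(List/Tuple)`. Unlike the original, the sketch catches only `ValueError` and `SyntaxError` from `literal_eval`, not everything via a bare `except`. The config values in the usage part are made up for illustration.

```python
from ast import literal_eval


def merge(base, other):
    """Recursive merge of `other` into `base`, like EasyConfig.update(Dict)."""
    for key, value in other.items():
        if isinstance(value, dict):
            if not isinstance(base.get(key), dict):
                base[key] = {}          # new nested level
            merge(base[key], value)     # recurse instead of replacing the whole sub-dict
        else:
            base[key] = value           # leaf value: later files win
    return base


def apply_opts(base, opts):
    """Apply 'a.b=v' or '--a.b v' overrides, like EasyConfig.update(List/Tuple)."""
    index = 0
    while index < len(opts):
        opt = opts[index]
        if opt.startswith('--'):
            opt = opt[2:]
        if '=' in opt:
            key, value = opt.split('=', 1)
            index += 1
        else:
            key, value = opt, opts[index + 1]
            index += 2
        try:
            value = literal_eval(value)   # '0.01' -> 0.01, '16' -> 16
        except (ValueError, SyntaxError):
            pass                          # e.g. 'train' stays a string
        *parents, last = key.split('.')
        current = base
        for subkey in parents:
            current = current.setdefault(subkey, {})
        current[last] = value
    return base


# Made-up values, loaded from the general file to the specific one
cfg = {}
merge(cfg, {'lr': 0.001, 'model': {'NAME': 'BaseSeg', 'encoder_args': {'width': 64}}})
merge(cfg, {'model': {'encoder_args': {'width': 32, 'in_channels': 4}}})
apply_opts(cfg, ['mode=train', '--lr', '0.01', 'model.encoder_args.width=16'])
print(cfg)
```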

3. The Resulting Configuration

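All entries in the .yaml files listed in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] are parsed and written into the cfg dict. Only files that exist are read; the top-level default.yaml does not exist.
The resulting cfg dict is shown below. Its cfg.model part is used to configure and build the network model (a class) automatically.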
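You can check the list of candidate files with a standalone copy of the directory walk in `load(recursive=True)`. `default_chain` is a hypothetical helper name, and POSIX-style paths are assumed. Only files that actually exist are loaded, in order from the most general to the most specific, so later files override earlier ones.

```python
import os


def default_chain(fpath):
    """Candidate config files, most specific first, as in EasyConfig.load(recursive=True)."""
    fpaths = [fpath]
    extension = os.path.splitext(fpath)[1]      # '.yaml'
    while os.path.dirname(fpath) != fpath:      # stop once dirname() no longer changes the path
        fpath = os.path.dirname(fpath)          # strip one path component
        fpaths.append(os.path.join(fpath, 'default' + extension))
    return fpaths


chain = default_chain('cfgs/s3dis/pointnext-s.yaml')
print(chain)
# ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml']
for path in reversed(chain):   # load order: most general first, so specific files win
    print(path)
```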

dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml
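The path-like fields at the end of the dump above appear to be derived from a few base values: task_name, mode, cfg_basename, the GPU count, a timestamp, and a suffix that looks like a randomly generated ID. The sketch below rebuilds them by string concatenation. The joining pattern is inferred only from the printed values, not taken from the PointNeXt source. The helper name derive_run_paths and its root argument are hypothetical.

```python
def derive_run_paths(root, task_name, mode, cfg_basename, ngpus, timestamp, uid):
    """Hypothetical helper: rebuild the run-related paths seen in the cfg dump.

    The pattern is reverse-engineered from the printed output, not copied
    from PointNeXt.
    """
    # e.g. s3dis-train-pointnext-s-ngpus1-20240802-170230-<random id>
    run_name = "-".join([task_name, mode, cfg_basename, f"ngpus{ngpus}", timestamp, uid])
    run_dir = f"{root}/{task_name}/{run_name}"
    return {
        "run_name": run_name,
        "run_dir": run_dir,  # printed identically as log_dir and exp_dir
        "ckpt_dir": f"{run_dir}/checkpoint",
        "log_path": f"{run_dir}/{run_name}.log",
        "cfg_path": f"{run_dir}/cfg.yaml",
    }


paths = derive_run_paths("log", "s3dis", "train", "pointnext-s", 1,
                         "20240802-170230", "MzNANkH27yLSBYpmpmF5tu")
for key, value in paths.items():
    print(f"{key}: {value}")
```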

III. Summary

1. Result

The part of the main() function in examples/segmentation/main.py that builds the deep network model (as a class instance):

    if cfg.model.get('in_channels', None) is None:
        cfg.model.in_channels = cfg.model.encoder_args.in_channels # 4
    model = build_model_from_cfg(cfg.model).to(cfg.rank)
    model_size = cal_model_parm_nums(model)
    logging.info(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
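The snippet above relies on cfg supporting both attribute access (cfg.model.encoder_args.in_channels) and dict-style .get(). Below is a minimal stand-in with that behavior. The class name AttrDict is hypothetical and is not the config class PointNeXt actually uses. The example shows how the in_channels fallback resolves to 4 for this configuration.

```python
class AttrDict(dict):
    """Hypothetical dict whose keys can also be read and written as attributes."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            # Wrap nested dicts recursively so cfg.a.b.c works at any depth.
            self[key] = AttrDict(value) if isinstance(value, dict) else value

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so dict methods
        # such as .get() still work as usual.
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


cfg = AttrDict({
    "model": {
        "NAME": "BaseSeg",
        "encoder_args": {"NAME": "PointNextEncoder", "in_channels": 4},
    }
})

# Same fallback as in main(): inherit in_channels from the encoder if absent.
if cfg.model.get("in_channels", None) is None:
    cfg.model.in_channels = cfg.model.encoder_args.in_channels
print(cfg.model.in_channels)  # 4
```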

Calling build_model_from_cfg(cfg.model) in turn executes openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which produces the following network structure:

 BaseSeg(
  (encoder): PointNextEncoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): SetAbstraction(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            )
          )
        )
      )
      (1): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (2): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (3): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (4): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
    )
  )
  (decoder): PointNextDecoder(
    (decoder): Sequential(
      (0): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (head): SegHead(
    (head): Sequential(
      (0): Sequential(
        (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Dropout(p=0.5, inplace=False)
      (2): Sequential(
        (0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
      )
    )
  )
) 
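The channel sizes in the printout follow a regular pattern that a few lines of arithmetic can check. The sketch below recomputes them from in_channels: 4, width: 32, the five encoder stages and num_classes: 13 in the config. The rules it encodes are read off the printed model, not taken from the source:

- The width doubles at every stage after the stem.
- The first Conv2d of each SetAbstraction takes 3 extra input channels. These are presumably the relative coordinates dp implied by feature_type: dp_fj.
- Each FeaturePropogation block (spelled as in the printout) takes a skip feature concatenated with the coarser stage's feature.

The function name expected_channels is hypothetical.

```python
def expected_channels(in_channels=4, width=32, num_stages=5, num_classes=13, coord_dim=3):
    """Recompute conv (in, out) channel pairs seen in the printed BaseSeg model.

    The rules are inferred from the printout, not from the PointNeXt code.
    """
    widths = [width * 2 ** i for i in range(num_stages)]  # [32, 64, 128, 256, 512]
    stem = (in_channels, widths[0])                       # stage 0: Conv1d(4, 32)
    encoder = []
    for i in range(1, num_stages):
        c_in, c_out = widths[i - 1], widths[i]
        # (first Conv2d), (second Conv2d), (skipconv Conv1d) of stage i
        encoder.append(((c_in + coord_dim, c_in), (c_in, c_out), (c_in, c_out)))
    # Decoder block i: skip feature of stage i concatenated with stage i + 1.
    decoder = [(widths[i] + widths[i + 1], widths[i]) for i in range(num_stages - 1)]
    head = [(widths[0], widths[0]), (widths[0], num_classes)]
    return widths, stem, encoder, decoder, head


widths, stem, encoder, decoder, head = expected_channels()
print(widths)      # [32, 64, 128, 256, 512]
print(encoder[0])  # ((35, 32), (32, 64), (32, 64))
print(decoder)     # [(96, 32), (192, 64), (384, 128), (768, 256)]
print(head)        # [(32, 32), (32, 13)]
```

Each tuple can be compared directly with the Conv1d/Conv2d lines above. For example, the stage-4 block prints Conv2d(259, 256), which matches encoder[3][0] == (259, 256).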

2. Todo

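How is the network structure above configured and constructed automatically? This is left for a further reading of the source code.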
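Part of the answer is presumably the registry mechanism from Section I. There, the NAME field of a config dict selects a registered class and the remaining keys become constructor arguments. The sketch below strips that idea down to a simplified Registry with only register_module, get and build. It is not the openpoints implementation: for example, it omits the force option. BaseSegStub is a hypothetical placeholder class.

```python
import copy


class Registry:
    """Simplified string-to-class registry (illustrative sketch only)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        def _register(cls):
            key = name or cls.__name__
            if key in self._module_dict:
                raise KeyError(f"{key} is already registered in {self.name}")
            self._module_dict[key] = cls
            return cls  # the decorated class is returned unchanged
        return _register

    def get(self, key):
        return self._module_dict.get(key)

    def build(self, cfg, **default_args):
        args = copy.deepcopy(dict(cfg))
        obj_type = args.pop("NAME")  # the string selects the class
        obj_cls = self.get(obj_type)
        if obj_cls is None:
            raise KeyError(f"{obj_type} is not registered in {self.name}")
        for key, value in default_args.items():
            args.setdefault(key, value)
        return obj_cls(**args)  # remaining keys become keyword arguments


MODELS = Registry("models")


@MODELS.register_module()
class BaseSegStub:
    """Hypothetical stand-in for BaseSeg that only stores its arguments."""

    def __init__(self, encoder_args=None, decoder_args=None, cls_args=None, **kwargs):
        self.encoder_args = encoder_args
        self.decoder_args = decoder_args
        self.cls_args = cls_args
        self.extra = kwargs


model = MODELS.build({
    "NAME": "BaseSegStub",
    "encoder_args": {"NAME": "PointNextEncoder", "width": 32},
    "decoder_args": {"NAME": "PointNextDecoder"},
    "cls_args": {"NAME": "SegHead", "num_classes": 13},
    "in_channels": 4,
})
print(type(model).__name__, model.extra)  # BaseSegStub {'in_channels': 4}
```

In the real config, encoder_args, decoder_args and cls_args each carry their own NAME. This suggests the same lookup is repeated one level down when BaseSeg builds its encoder, decoder and head. The sketch leaves those nested dicts as plain data.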

Many thanks to the authors of the paper and the code for open-sourcing their research!

