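Title: Reading the PointNeXt 3D Point Cloud Network Source Code (I): Registry Mechanism and Argument Parsing
Related posts
[1] Installing, Configuring, and Testing the PointNeXt 3D Point Cloud Network
[2] Reading the PointNeXt 3D Point Cloud Network Source Code (I): Registry Mechanism and Argument Parsing ⇐ this post
[3] Reading the PointNeXt 3D Point Cloud Network Source Code (II): Point Cloud Dataset Construction and Preprocessing
[4] Reading the PointNeXt 3D Point Cloud Network Source Code (III): Backbone Network Models
[5] Reading the PointNeXt 3D Point Cloud Network Source Code (IV): PointNeXt-B
Preface
I have read through part of the PointNeXt source code and am recording my notes here for future reference.
This post has two parts, the registry mechanism and argument parsing; the registry mechanism is the main thing to understand.
All comments and debug output below are based on the following test session.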
CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py \
--cfg cfgs/s3dis/pointnext-s.yaml mode=train
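I. Registry Mechanism
The registry mechanism is how PointNeXt registers its modules/classes so that a string can be mapped to a module/class. In other words, a name string read from a configuration file can be mapped directly to the corresponding class, from which an instance is then built. The PointNeXt authors based this part of the implementation on the registry mechanism in mmcv.
1. The Registry Class
The mechanism itself is implemented by the class Registry. Its key methods, together with the companion function build_from_cfg, are:
Method | Description |
---|---|
__init__() | Initializes the class, including the dictionary of registered modules, self._module_dict = dict() |
get(self, key) | Maps a string to a class: the string key is resolved to the registered class self._module_dict[real_key] |
register_module(self, name=None, force=False, module=None) | Registers a module/class; it can be called as a normal method or used as a decorator on a module/class |
_register(cls) | The inner wrapper function returned by register_module when it is used as a decorator; cls is the class being decorated. It first calls _register_module(self, module_class, module_name=None, force=False), which stores the class via self._module_dict[name] = module_class, and then returns cls unchanged, with no other processing |
build_from_cfg(cfg, registry, default_args=None) | A module-level function (the default build_func, not a method of Registry) that builds an instance from a config dict, i.e. turns a name string into an instance of the registered class |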
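The string-to-class mapping described in the table can be illustrated with a stripped-down stand-in. This is only a sketch: `MiniRegistry` and `ToyEncoder` are hypothetical names invented for this example, and the scope, parent/child, and deprecated-API handling of the real Registry are left out.

```python
class MiniRegistry:
    """Hypothetical, minimal stand-in for Registry: maps names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None, force=False):
        # Returns a decorator; the decorator stores the class and returns it unchanged.
        def _register(cls):
            key = name or cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register

    def get(self, key):
        # String -> class lookup; returns None when the key is unknown.
        return self._module_dict.get(key)


MODELS = MiniRegistry('models')


@MODELS.register_module()
class ToyEncoder:
    def __init__(self, width=32):
        self.width = width


encoder_cls = MODELS.get('ToyEncoder')  # the class itself, not an instance
encoder = encoder_cls(width=64)         # instantiate it like any other class
print(encoder_cls is ToyEncoder, encoder.width)
```

The lookup returns the class object; building an instance from it is a separate step, which is what build_from_cfg adds on top.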
The Registry class is defined in openpoints/utils/registry.py; the annotated source follows.
# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy


class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))

    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str
    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
        split_filename = filename.split('.')  # ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # declares a static method (it takes no self argument)
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        # str.find returns -1 when key contains no '.'; otherwise it returns
        # the index of the first '.'
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children
    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        # Maps a string to a class: the string key is resolved to the
        # registered class self._module_dict[real_key].
        scope, real_key = self.split_scope_key(key)
        # key = 'BaseSeg'; scope = None; real_key = 'BaseSeg'
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls
    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        # This method doubles as a decorator factory (see the end of the body).
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        # NOTE: the import of `misc` is not shown in this excerpt
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # Plain call of register_module, not the decorator case
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # This is the decorator's inner wrapper function.
        # In the decorator case, cls is the class being decorated.
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator factory returns this wrapper
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to the class.
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')
        # Remove the "NAME" entry: it has already selected obj_cls, and the
        # remaining entries become the constructor's keyword arguments.
        # With the values above, the call below is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # This is an ordinary instantiation. The @MODELS.register_module()
        # decorator ran once, when the class was defined at import time,
        # and is not triggered again here.
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
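The scope handling in get() depends on how split_scope_key cuts the key at its first dot. The standalone copy below, rewritten as a plain function purely for illustration, shows the three cases: no scope, one scope, and a key with several dots.

```python
def split_scope_key(key):
    """Standalone copy of Registry.split_scope_key, for illustration only."""
    split_index = key.find('.')  # -1 when there is no '.'
    if split_index != -1:
        # Only the first '.' separates the scope; the rest stays in the key.
        return key[:split_index], key[split_index + 1:]
    return None, key


print(split_scope_key('BaseSeg'))             # (None, 'BaseSeg')
print(split_scope_key('openpoints.BaseSeg'))  # ('openpoints', 'BaseSeg')
print(split_scope_key('a.b.C'))               # ('a', 'b.C')
```

In the debug session of this post the key is simply 'BaseSeg', so scope is None and the lookup goes straight to self._module_dict.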
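2. Registering Classes
First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to below as the registry.
Because of Python's import mechanism, the registry MODELS is created as soon as this module is imported, which happens at program start-up, before __main__/main() runs.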
from openpoints.utils import registry

MODELS = registry.Registry('models')
# Create a registry.Registry object (MODELS, also called the registry) as a module-level global.
# This runs at import time, before __main__/main() executes, so MODELS exists from the very start of the program.


def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT):
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)
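Also because of the import mechanism, once MODELS exists, every class under openpoints/models whose definition is decorated with @MODELS.register_module() is registered in MODELS as soon as the module defining it is imported and executed.
For example, the class BaseSeg below is registered in MODELS._module_dict in this way.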
"""
Author: PointNeXt
"""
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d
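# The class BaseSeg is decorated with MODELS.register_module().
# The decorator runs once, when the class statement is executed; it is equivalent to
# BaseSeg = MODELS.register_module()(BaseSeg). Creating BaseSeg() instances later does not register anything again.
# This happens at import time, before __main__/main() but after the registry MODELS has been created.
# So the class is already registered by the time main() is called, and the registry can then turn the string into the class.
@MODELS.register_module()
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()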
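A small experiment makes the timing concrete: the decorator body runs once, when the class statement executes, and not on instantiation. This is a sketch; `registered`, `ToySeg`, and `ToyCls` are hypothetical names, and the free function `register_module` below only imitates the shape of Registry.register_module.

```python
registered = []  # records every registration, in order


def register_module():
    # Decorator factory with the same shape as Registry.register_module().
    def _register(cls):
        registered.append(cls.__name__)
        return cls  # the class is returned unchanged
    return _register


@register_module()
class ToySeg:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


# The @ syntax above is shorthand for this explicit form:
class ToyCls:
    pass


ToyCls = register_module()(ToyCls)

# Instantiating the class does not run the decorator again.
a = ToySeg(num_classes=13)
b = ToySeg(num_classes=13)
print(registered)  # ['ToySeg', 'ToyCls']
```

Each class appears exactly once in `registered`, no matter how many instances are created afterwards.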
Inspecting MODELS._module_dict while debugging shows that many classes have already been registered.
MODELS._module_dict = {
'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>,
'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>,
'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>,
'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>,
'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>,
'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>,
'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>,
'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>,
'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>,
'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>,
'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>,
'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>,
'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>,
'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>,
'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>,
'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>,
'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>,
'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>,
'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>,
'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>,
'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>,
'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>,
'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>,
'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>,
'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>,
'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>,
'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>,
'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>,
'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>,
'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>,
'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>,
'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>,
'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>,
'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>,
'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>,
'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}
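3. Using the Registry
With the registry MODELS in place and the classes registered in it, its string-to-class mapping makes it easy to create class instances from the configuration in a .yaml file.
The rough call sequence is: cfg read from the .yaml file → build_model_from_cfg(cfg.model) → MODELS.build(cfg) → build_from_cfg(cfg, registry=MODELS) → registry.get(NAME) → obj_cls(**obj_cfg).
The configuration dict cfg read from the .yaml file contains a NAME entry, and registry.get(*) returns the registered class that the NAME string refers to. Once the class has been obtained, the NAME entry has done its job; the remaining configuration entries are used to configure and construct the concrete deep neural network modules/classes in PointNeXt automatically (not covered in this post).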
For the detailed comments, see the function build_from_cfg (a module-level function in registry.py, not a method of Registry), repeated here:
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to the class.
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')
        # Remove the "NAME" entry: it has already selected obj_cls, and the
        # remaining entries become the constructor's keyword arguments.
        # With the values above, the call below is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # This is an ordinary instantiation. The @MODELS.register_module()
        # decorator ran once, when the class was defined at import time,
        # and is not triggered again here.
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
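The core of this function (look up NAME, drop it, pass the rest as keyword arguments) can be tried out in isolation. The sketch below is a simplification under stated assumptions: `REGISTRY` is a plain dict standing in for a Registry object, `ToyHead` and `build` are hypothetical names, and the type checks of the real function are omitted.

```python
import copy


class ToyHead:
    """Hypothetical class standing in for a registered module."""

    def __init__(self, num_classes, in_channels=None):
        self.num_classes = num_classes
        self.in_channels = in_channels


REGISTRY = {'ToyHead': ToyHead}  # plain dict used in place of a Registry


def build(cfg, registry, default_args=None):
    obj_cls = registry.get(cfg['NAME'])  # string -> class
    if obj_cls is None:
        raise KeyError(f"{cfg['NAME']} is not registered")
    obj_cfg = copy.deepcopy(cfg)         # leave the caller's cfg untouched
    if default_args is not None:
        obj_cfg.update(default_args)     # as in the listing, these overwrite cfg
    obj_cfg.pop('NAME')                  # NAME is not a constructor argument
    try:
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Prefix the class name so the failing module is easy to identify.
        raise type(e)(f'{obj_cls.__name__}: {e}')


cfg = {'NAME': 'ToyHead', 'num_classes': 13}
head = build(cfg, REGISTRY, default_args={'in_channels': 32})
print(head.num_classes, head.in_channels, 'NAME' in cfg)
```

Note that, following the listing, an entry in default_args overwrites an entry with the same key in cfg, because update() is applied to the copy of cfg.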
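II. Argument Parsing
The registry mechanism needs string parameters to be passed in before it can build class instances. Those parameters are obtained by a parsing step that reads the configuration in the .yaml files into the program.
1. Command-Line Parsing
The main program first reads the relevant .yaml files and merges them into the dict variable cfg. The annotated code follows.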
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    # Create the parser
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    # Add the arguments
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by the parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']
    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    cfg.update(opts)
    # overwrite the default arguments in yml
    # mode=train is written into the cfg dict
    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1  # when debugging, only a single GPU is used; in a normal run several GPUs can work in parallel

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]
    # task/dataset name, e.g. s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]
    # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the name of the folder under ./cfgs)
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = []  # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join the kept options with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)
        # requires pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # write cfg to the file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
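The way parse_known_args separates the declared flags from the free-form key=value overrides, and the way the task name and config basename are cut out of the path, can be reproduced without PointNeXt. This is a minimal sketch; the argv list is passed explicitly instead of being read from the command line, and the pathlib lines are an alternative shown for comparison, not what main.py does.

```python
import argparse
from pathlib import PurePosixPath

parser = argparse.ArgumentParser('demo')
parser.add_argument('--cfg', type=str, required=True, help='config file')
parser.add_argument('--profile', action='store_true', default=False)

# Same arguments as the test session, given as a list instead of sys.argv.
argv = ['--cfg', 'cfgs/s3dis/pointnext-s.yaml', 'mode=train']
args, opts = parser.parse_known_args(argv)
# Declared options land in args; everything unrecognised is returned in opts.
print(args.cfg, args.profile, opts)

# String splitting as in main.py.
stem = args.cfg.split('.')[-2]        # 'cfgs/s3dis/pointnext-s'
task_name = stem.split('/')[-2]       # 's3dis'
cfg_basename = stem.split('/')[-1]    # 'pointnext-s'

# Alternative with pathlib, which does not depend on the number of dots.
path = PurePosixPath(args.cfg)
print(task_name == path.parent.name, cfg_basename == path.stem)
```

The split('.') approach assumes the path contains exactly one dot before the extension; the pathlib variant gives the same result for this path and is less sensitive to extra dots in file names.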
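2. Loading and Updating Parameters
Reading and updating the configuration entries is implemented in the class EasyConfig. Part of it is annotated below.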
class EasyConfig(dict):
    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml

        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        # 'cfgs/s3dis/pointnext-s.yaml'
        if recursive:  # True
            extension = os.path.splitext(fpath)[1]  # '.yaml'
            while os.path.dirname(fpath) != fpath:  # loop until dirname() no longer changes fpath, i.e. the top of the path is reached
                fpath = os.path.dirname(fpath)  # drop the last path component, moving up one level each time
                fpaths.append(os.path.join(fpath, 'default' + extension))
        # fpaths = ['cfgs/s3dis/pointnext-s.yaml',
        #           'cfgs/s3dis/default.yaml',
        #           'cfgs/default.yaml',
        #           'default.yaml']
        for fpath in reversed(fpaths):  # most general file first, so more specific files override it
            if os.path.exists(fpath):
                with open(fpath) as f:
                    self.update(yaml.safe_load(f))
        # The entries of all existing .yaml files in fpaths end up merged into one dict

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # multimethod makes Python support function overloading
    @multimethod
    def update(self, other: Dict) -> None:
        # Turn the .yaml items into key:value pairs of this dict
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):
                    # nested entry
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except:
                pass  # not a Python literal: keep the raw string
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                seperator = '\n'
            else:
                seperator = ' '
            text = key + ':' + seperator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
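How the list overload of update() turns command-line overrides such as mode=train into (possibly nested) config entries can be sketched with plain dicts. This is a simplified illustration: `apply_overrides` is a hypothetical helper, only the key=value form is handled (not the '--key value' form), and the exception list is narrowed to the two errors literal_eval commonly raises on non-literal text.

```python
from ast import literal_eval


def apply_overrides(cfg, opts):
    """Apply 'a.b.c=value' overrides to a nested dict, in place."""
    for opt in opts:
        key, value = opt.split('=', 1)
        try:
            value = literal_eval(value)  # '64' -> 64, 'False' -> False
        except (ValueError, SyntaxError):
            pass                         # e.g. 'train' stays a string
        current = cfg
        subkeys = key.split('.')
        for subkey in subkeys[:-1]:
            # Walk down, creating intermediate dicts when they are missing.
            current = current.setdefault(subkey, {})
        current[subkeys[-1]] = value
    return cfg


cfg = {'mode': 'test', 'model': {'encoder_args': {'width': 32}}}
apply_overrides(cfg, ['mode=train',
                      'model.encoder_args.width=64',
                      'lr=0.01',
                      'wandb.use_wandb=False'])
print(cfg)
```

Dotted keys address nested entries, so a single command-line token can change a value deep inside cfg.model without editing the .yaml file.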
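3. The Resulting Parameters
All entries in the .yaml files listed in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] are parsed and written into the dict variable cfg. Files that do not exist are skipped; here the top-level default.yaml does not exist.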
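The candidate list can be reproduced by running the directory-climbing loop from EasyConfig.load on its own. In this sketch `candidate_cfg_paths` is a hypothetical helper name, and posixpath is used instead of os.path so that the result is the same on every platform (on Linux os.path is posixpath).

```python
import posixpath


def candidate_cfg_paths(fpath):
    """Collect fpath plus a default.<ext> for every parent directory."""
    fpaths = [fpath]
    extension = posixpath.splitext(fpath)[1]        # '.yaml'
    while posixpath.dirname(fpath) != fpath:        # stop at the top of the path
        fpath = posixpath.dirname(fpath)            # climb one directory level
        fpaths.append(posixpath.join(fpath, 'default' + extension))
    return fpaths


paths = candidate_cfg_paths('cfgs/s3dis/pointnext-s.yaml')

# Files are loaded in reversed order: general defaults first, the specific
# file last, so the specific file has the final word on shared keys.
for p in reversed(paths):
    print(p)
```

Loading in reversed order is what lets cfgs/s3dis/pointnext-s.yaml override values from cfgs/s3dis/default.yaml, which in turn override cfgs/default.yaml.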
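The dict variable cfg obtained after parsing is shown below. Its cfg.model part is what will be used to configure and build the network model (the class instances) automatically.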
dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml
III. Summary
1. Result
The part of main() in examples/segmentation/main.py that builds the deep network model (the class instance) is:
if cfg.model.get('in_channels', None) is None:
    cfg.model.in_channels = cfg.model.encoder_args.in_channels  # 4
model = build_model_from_cfg(cfg.model).to(cfg.rank)
model_size = cal_model_parm_nums(model)
logging.info(model)
logging.info('Number of params: %.4f M' % (model_size / 1e6))
The call build_model_from_cfg(cfg.model) ends up executing openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which yields the following network structure:
BaseSeg(
(encoder): PointNextEncoder(
(encoder): Sequential(
(0): Sequential(
(0): SetAbstraction(
(convs): Sequential(
(0): Sequential(
(0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
)
)
)
)
(1): Sequential(
(0): SetAbstraction(
(skipconv): Sequential(
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
)
(act): ReLU(inplace=True)
(convs): Sequential(
(0): Sequential(
(0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(grouper): QueryAndGroup()
)
)
(2): Sequential(
(0): SetAbstraction(
(skipconv): Sequential(
(0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
)
(act): ReLU(inplace=True)
(convs): Sequential(
(0): Sequential(
(0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(grouper): QueryAndGroup()
)
)
(3): Sequential(
(0): SetAbstraction(
(skipconv): Sequential(
(0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
)
(act): ReLU(inplace=True)
(convs): Sequential(
(0): Sequential(
(0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(grouper): QueryAndGroup()
)
)
(4): Sequential(
(0): SetAbstraction(
(skipconv): Sequential(
(0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
)
(act): ReLU(inplace=True)
(convs): Sequential(
(0): Sequential(
(0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(grouper): QueryAndGroup()
)
)
)
)
(decoder): PointNextDecoder(
(decoder): Sequential(
(0): Sequential(
(0): FeaturePropogation(
(convs): Sequential(
(0): Sequential(
(0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
(1): Sequential(
(0): FeaturePropogation(
(convs): Sequential(
(0): Sequential(
(0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
(2): Sequential(
(0): FeaturePropogation(
(convs): Sequential(
(0): Sequential(
(0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
(3): Sequential(
(0): FeaturePropogation(
(convs): Sequential(
(0): Sequential(
(0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
)
)
(head): SegHead(
(head): Sequential(
(0): Sequential(
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Dropout(p=0.5, inplace=False)
(2): Sequential(
(0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
)
)
)
)
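2. Todo
How is the network structure above configured and constructed automatically? That remains to be understood by reading more of the source code.
Thanks to the authors of the paper and the code for open-sourcing their research!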