1. 定义一个自己的类
在csnmemde.py中,导入mmaction.models.builder 中的HEADS,HEADS注册器写在class上面。在其中编写自己需要的类,其中需要有init_weights(),这个方法里的内容可以自己定义。
还定义了一个mmdet_imported用来最后的一步register_module() 。当被引用时,就将该类注册到全局。
# csnmemde.py
from mmaction.models.builder import HEADS
try:
from mmdet.models import BACKBONES as MMDET_BACKBONES # 定义backbone时用到这句
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False
@HEADS.register_module()
class ResNetCSNMem(nn.Module):
def __init__(self, chnum_in, mem_dim, feature_num,
feature_num_2, feature_num_x2, feature_num_x4,
feature_num_x6, feature_num_x8, shrink_thres=0.0025):
super(ResNetCSNMem, self).__init__()
print('ResNetCSNCov3DMem')
self.chnum_in = chnum_in # 通道数
self.feature_num = feature_num
self.feature_num_2 = feature_num_2
self.feature_num_x2 = feature_num_x2
self.feature_num_x4 = feature_num_x4
self.feature_num_x6 = feature_num_x6
self.feature_num_x8 = feature_num_x8
def init_weights(self):
pass
if mmdet_imported:
MMDET_SHARED_HEADS.register_module()(ResNetCSNMem)
2. 在该类所在的文件夹的__init__.py中添加ResNetCSNMem类
from .x3d_head import X3DHead
from .csnmemde import *
2.加载model
定义完一个新的类,在需要用到的地方,要import它(from models.head import *),而且要确保重新activate了对应的虚拟环境, 重新activate了对应的虚拟环境, 重新activate了对应的虚拟环境, 这样这个新类才会注册到mmaction全局。
from mmaction.models import build_head
from mmcv import Config
from models.head import *
cfg = Config.fromfile('config/csncfg.py')
memde = build_head(cfg.model.cls_head)
判断是否将新的类注册到全局,打印其HEADS看一下,注册成功!
from mmaction.models import HEADS
HEADS
该类存放的地址为编写这个类的地址路径
对于mmdet的注册方法一样,还可以自己定义BACKBONE, NECK等。这里完成了Registry的部分,如果用这个类,还需要定义config,使用时builder一下。