mmflow-官方教程翻译05 Adding New Modules

原文链接:Tutorial 4: Adding New Modules — mmflow documentation

教程4 添加新的模块

MMFlow把一个光流估计方法flow_estimator拆分成了编码器encoder和解码器decoder。本教程展示如何添加新的部件。

添加新的编码器

1.创建一个新文件:mmflow/models/encoders/my_model.py

from mmcv.runner import BaseModule

from ..builder import ENCODERS

@ENCODERS.register_module()
class MyModel(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass

2.在mmflow/models/encoders/__init__.py导入该模块(MyModel)

from .my_model import MyModel

添加新的解码器

1.创建新文件:mmflow/models/decoders/my_decoder.py

你可以写一个新的从MMCV的BaseModule中继承的头,然后覆写forward(self, x), forward_train和forward_test等方法。我们在MMCV中有一个统一的接口用于权重初始化,你可以使用init_cfg用于指定初始化函数和参数,或者覆写init_weights()方法。(如果你更喜欢自定义初始化的话)

from ..builder import DECODERS


@DECODERS.register_module()
class MyDecoder(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, *args):
        pass

    # optional
    def init_weights(self):
        pass

    def forward_train(self, *args, flow_gt):
        flow_pred = self.forward(*args)
        return self.losses(flow_pred, flow_gt)

    def forward_test(self,*args, img_metas):
        flow_pred = self.forward(*args)
        return self.get_flow(flow_pred, img_metas)

loss是用于计算模型输出和目标之间损失值的损失函数,get_flow在BaseDecoder中实现,用于把光流的形状尺寸恢复到原始输入图像的尺寸。

1.导入mmflow/models/decoders/__init__.py中的模块

from .my_decoder import MyDecoder

添加新的光流估计器

1.创建新文件:mmflow/models/flow_estimators/my_estimator.py

你可以写一个新的类似于PWC-Net的继承自FlowEstimator的头,然后实现forward_train和forward_test等方法。

from ..builder import FLOW_ESTIMATORS
from .base import FlowEstimator


@FLOW_ESTIMATORS.register_module()
class MyEstimator(FlowEstimator):

    def __init__(self, arg1, arg2):
        pass

    def forward_train(self, imgs):
        pass

    def forward_test(self, imgs):
        pass

2.导入mmflow/models/flow_estimator/__init__.py中的模块

from .my_estimator import MyEstimator

3.在你的配置文件中使用它

我们把模型类型设为MyEstimator

model = dict(
    type='MyEstimator',
    encoder=dict(
        type='MyModel',
        arg1=xxx,
        arg2=xxx),
    decoder=dict(
        type='MyDecoder',
        arg1=xxx,
        arg2=xxx))

添加新的损失函数

假设你想增加一个新的名为MyLoss的损失函数用于光流估计。为了添加新的损失函数,使用者需要在mmflow/models/losses/my_loss.py中编程实现。

import torch
import torch.nn as nn

from mmflow.models import LOSSES

def my_loss(pred, target):
    pass

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, arg1):
        super(MyLoss, self).__init__()


    def forward(self, output, target):
        return my_loss(output, target)

然后把它添加到mmflow/models/losses/__init__.py中

from .my_loss import MyLoss, my_loss

要想使用它,需要在模型的flow_loss字段中修改。

flow_loss=dict(type='MyLoss', use_target_weight=False)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值