mmflow——tutorial 4 添加新模型

在这里插入图片描述

添加新模型

mmflow将光流估计分解为编码器和解码器

添加一个新编码器

  • 创建新文件mmflow/models/encoders/my_model.py
from mmcv.runner import BaseModule

from ..builder import ENCODERS

@ENCODERS.register_module()
class MyModel(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass
  • 在 mmflow/models/encoders/init.py中引入模型
from .my_model import MyModel

添加新解码器

  • 创建文件mmflow/models/decoders/my_decoder.py
from ..builder import DECODERS


@DECODERS.register_module()
class MyDecoder(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, *args):
        pass

    # optional
    def init_weights(self):
        pass

    def forward_train(self, *args, flow_gt):
        flow_pred = self.forward(*args)
        return self.losses(flow_pred, flow_gt)

    def forward_test(self,*args, img_metas):
        flow_pred = self.forward(*args)
        return self.get_flow(flow_pred, img_metas)
  • mmflow/models/decoders/__init__.py引入模型
from .my_decoder import MyDecoder

添加新光流估计器

  • 创建文件mmflow/models/flow_estimators/my_estimator.py
from ..builder import FLOW_ESTIMATORS
from .base import FlowEstimator


@FLOW_ESTIMATORS.register_module()
class MyEstimator(FlowEstimator):

    def __init__(self, arg1, arg2):
        pass

    def forward_train(self, imgs):
        pass

    def forward_test(self, imgs):
        pass
  • mmflow/models/flow_estimator/__init__.py
from .my_estimator import MyEstimator
  • 在配置文件中使用,将模型类型命名为MyEstimator
model = dict(
    type='MyEstimator',
    encoder=dict(
        type='MyModel',
        arg1=xxx,
        arg2=xxx),
    decoder=dict(
        type='MyDecoder',
        arg1=xxx,
        arg2=xxx))

添加新损失函数

如果要为光流估计器添加新的损失函数MyLoss,需要在mmflow/models/losses/my_loss.py实现

import torch
import torch.nn as nn

from mmflow.models import LOSSES

def my_loss(pred, target):
    pass

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, arg1):
        super(MyLoss, self).__init__()


    def forward(self, output, target):
        return my_loss(output, target)
  • 然后需要在mmflow/models/losses/__init__.py中添加损失函数
from .my_loss import MyLoss, my_loss
  • 最后,在模型的flow_loss中修改损失函数
flow_loss=dict(type='MyLoss', use_target_weight=False)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值