【mmsegmentation】Head模块(进阶)自定义自己的HEAD

1、定义自己的head

driving\models\dense_heads\shuai_head.py

import torch
from torch import nn
from collections import namedtuple
from mmengine.model import BaseModule
from mmseg.models import HEADS, build_head, build_loss

import sys
sys.path.append("D:/BaiduSyncdisk/SHUAI/")
from models.losses.shuai_loss import * # 注册 ShuaiLoss

@HEADS.register_module()
class ShuaiHead(BaseModule):
    def __init__(
        self,
        loss,
        task='RoadCls',
        meta_info=None,
        aux_annotation=None,
        source=None,

        num_classes=4,
        momentum = 0.01,
        epsilon = 1e-3,
        in_channels = 448,
        out_channels = 1280,
    ):
        super().__init__()
        print(" ShuaiHead __init__")
        self.task = task,
        self.loss = build_loss(loss)
        self.meta_info = meta_info,
    
        self.roadcls_head = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=out_channels, momentum=momentum, eps=epsilon),
            nn.ReLU6(inplace=True),
        )
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        dropout_rate = 0.3
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.fc = torch.nn.Linear(out_channels, num_classes)

        # [1 448 7 7]
    def forward(self, x):
        outputs = namedtuple("outputs","roadcls_pred")
        # [1,1280,7,7]
        x = self.roadcls_head(x)
        # [1,1280,1,1]
        x = self.avgpool(x)
        # [1,1280]
        x = x.view(x.size(0), -1)
        if self.dropout is not None:
            x = self.dropout(x)
        # [1,4]
        x = self.fc(x)

        pred_dict = outputs(roadcls_pred=x)
        print("ShuaiHead foward:",pred_dict)
        return pred_dict
    
    
    def forward_train(self,
                       head_args):
        """
        Forward call along with loss computation/
        """
        input , img_metas, annotations, train_cfg= head_args.values()
        target, is_annotation_present, sample_weight = annotations.values()

        device_id = input.device
        sample_ratio = 1.0

        pred_dict = self.forward(input)
        x = pred_dict.roadcls_pred

        train_loss = self.loss(x,
                          target,
                          device_id = device_id,
                          sample_ratio=sample_ratio)
        print("ShuaiHead foward:",train_loss)
        return train_loss

看下HEADS注册表(@HEADS.register_module())
在这里插入图片描述

  • 可以看到ShuaiHead可以被注册到HEADS
  • 其实,这里的HEADS是BACKBONES NECKS HEADS LOSSES SEGMENTORS的总和
from mmseg.registry import MODELS

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
  • 看下这的BaseModule,mmengine\model\base_module.py

在这里插入图片描述

2、调用Shuai_head

if __name__ == "__main__":
    print("call shuai_head:")
    # 1.配置 dict
    num_classes = 4
    shuai_loss = dict(type='ShuaiLoss',loss_weight=1.0,loss_name='loss_shuai')
    head = dict(loss=shuai_loss,type='ShuaiHead',num_classes=num_classes)
    # 从注册器中构建
    shuai_head = build_head(head)

    # 使用shuai head
    # 前向传播
    input = torch.Tensor(2,448,7,7)   # [B,C,H,W]
    output = shuai_head(input)

    # 前向传播 + Loss计算
    target = torch.Tensor(2,num_classes)
    annotations = {"targets":target, "is_annotation_present":True, "sample_weight":1}
    head_args = {"input":input,"img_metas":None, "annotations":annotations,"train_cfg":None}
    loss = shuai_head.forward_train(head_args)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BILLY BILLY

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值