YOLOv8添加MobileViTv3模块

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

一、导入MobileViTv3模块

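添加基于官方 MobileViTv3 实现改写的模块代码 MbViTV3。注意:它继承自 MbViTBkV2(MobileViTv2 模块),并用到 ultralytics 的 Conv、torch、Tensor 以及 typing 中的 Optional/Union/Sequence,这些依赖需要一并定义或导入。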

class MbViTV3(MbViTBkV2):
    """MobileViTv3 block layered on top of ``MbViTBkV2``.

    Relative to the parent, ``forward_spatial`` changes two things:

    * the global (attention) features are concatenated with the local
      (convolutional) features along the channel axis and projected back to
      ``in_channels`` by a 1x1 ``Conv`` without activation;
    * the block input is added back as a residual connection.

    Args:
        in_channels: Channel count of the input; because of the residual add,
            it is also the output channel count.
        attn_unit_dim: Channel width of the local/global branches. The fusion
            conv expects ``2 * attn_unit_dim`` concatenated channels.
        patch_h, patch_w, ffn_multiplier, n_attn_blocks, attn_dropout,
        dropout, ffn_dropout, conv_ksize, attn_norm_layer: Forwarded
            positionally and unchanged to ``MbViTBkV2.__init__``.
        enable_coreml_compatible_fn: If true, use the ``*_coreml`` unfold/fold
            helpers and register ``unfolding_weights`` as a non-persistent
            buffer. Otherwise use the ``*_pytorch`` helpers.
    """

    def __init__(
            self,
            in_channels: int,
            attn_unit_dim: int,
            patch_h: Optional[int] = 2,
            patch_w: Optional[int] = 2,
            ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,
            n_attn_blocks: Optional[int] = 2,
            attn_dropout: Optional[float] = 0.0,
            dropout: Optional[float] = 0.0,
            ffn_dropout: Optional[float] = 0.0,
            conv_ksize: Optional[int] = 3,
            attn_norm_layer: Optional[str] = "layer_norm_2d",
            enable_coreml_compatible_fn: Optional[bool] = True,
    ) -> None:
        # The parent is called positionally, so this order must match its signature.
        parent_args = (
            in_channels, attn_unit_dim, patch_h, patch_w, ffn_multiplier,
            n_attn_blocks, attn_dropout, dropout, ffn_dropout, conv_ksize,
            attn_norm_layer,
        )
        super().__init__(*parent_args)

        self.enable_coreml_compatible_fn = enable_coreml_compatible_fn
        if enable_coreml_compatible_fn:
            # persistent=False keeps these precomputed weights out of the state_dict.
            self.register_buffer(
                "unfolding_weights", self._compute_unfolding_weights(), persistent=False
            )

        # 1x1 fusion conv: cat(global, local) has 2 * attn_unit_dim channels -> in_channels.
        self.conv_proj = Conv(attn_unit_dim * 2, in_channels, 1, 1, act=False)

    def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Apply local conv, patch attention, local/global fusion and a residual add.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        x = self.resize_input_if_needed(x)
        local_feats = self.local_rep(x)

        # Pick the feature-map <-> patches converters once for this call.
        if self.enable_coreml_compatible_fn:
            unfold, fold = self.unfolding_coreml, self.folding_coreml
        else:
            unfold, fold = self.unfolding_pytorch, self.folding_pytorch

        patches, output_size = unfold(local_feats)
        # Global representation learned across all patches.
        patches = self.global_rep(patches)
        global_feats = fold(patches=patches, output_size=output_size)

        # MobileViTv3: fuse global and local features (not global only),
        # then add the block input back as a skip connection.
        fused = self.conv_proj(torch.cat((global_feats, local_feats), dim=1))
        return fused + x

二、在ultralytics/nn/tasks.py中导入

先在 tasks.py 顶部导入 MbViTV3(导入路径取决于上面代码所放的模块),再在函数 parse_model 的模块分支中为 MbViTV3 添加一个 elif 分支:

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary (excerpt).

    Only the lines relevant to MbViTV3 are shown. Lines consisting of
    ``.......`` stand for unchanged upstream code that was elided.
    """
    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
    	.......  # elided: the existing if/elif branches for the other module types
    	elif m in {MbViTV3}:
            # MbViTV3 returns conv_proj(...) + x, so its output channel count equals
            # its first constructor argument (in_channels).
            # NOTE(review): args[0] is used verbatim from the YAML. It is not scaled by
            # width_multiple (gw) and not checked against ch[f] (presumably the source
            # layer's output channels). Deriving it from the input, e.g.
            # `c2 = ch[f]; args = [c2, *args[1:]]`, would avoid a channel mismatch when
            # width_multiple != 1.0. Confirm against the full parse_model first.
            c2 = args[0]
        .......  # elided: remainder of the loop body (unchanged upstream code)

三、在yaml配置文件中写入

在第 15 层(C2f,P3/8-small)之后添加 MbViTV3 层(即第 16 层),并把 Detect 的输入改为第 16 层。注意:这里的 head 只保留了 P3 这一个检测尺度。

# Parameters
nc: 10  # number of classes
depth_multiple: 0.33  # scales module repeats
width_multiple: 1.00  # scales convolution channels

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [16, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [32, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [32, True]]  # 2
  - [-1, 1, Conv, [64, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [64, True]]  # 4
  - [-1, 1, Conv, [128, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [128, True]]  # 6
  - [-1, 1, Conv, [256, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [256, True]]  # 8
  - [-1, 1, SPPF, [256, 5]]  # 9

# YOLOv8.0n head (reduced: only the P3 branch is kept, refined by MbViTV3)
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]  # 10
  - [[-1, 6], 1, Concat, [1]]  # 11 cat backbone P4
  - [-1, 3, C2f, [128]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]  # 13
  - [[-1, 4], 1, Concat, [1]]  # 14 cat backbone P3
  - [-1, 3, C2f, [64]]  # 15 (P3/8-small)

  # 16: MobileViTv3 block; args presumably map to MbViTV3(in_channels, attn_unit_dim).
  # in_channels must equal layer 15's output channels (presumably 64 with width_multiple 1.00);
  # parse_model takes it verbatim as this layer's output channels (not width-scaled).
  - [-1, 1, MbViTV3, [64, 32]]

  - [[16], 1, Detect, [nc]]  # 17 Detect(P3 only, fed by MbViTV3 layer 16)
  • 11
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值