用mindspore写了一个decouplehead模块,在yolov5上使用时报错

问题描述:

1.平台:gpu

2.版本:1.2.0

decouplehead模块代码如下:

import mindspore as ms

import mindspore.nn as nn

from mindspore.common.tensor import Tensor

from mindspore import context

from mindspore.context import ParallelMode

from mindspore.parallel._auto_parallel_context import auto_parallel_context

from mindspore.communication.management import get_group_size

from mindspore.ops import operations as P

from mindspore.ops import functional as F

from mindspore.ops import composite as C



class BaseConv(nn.Cell):

    """A Conv2d -> Batchnorm -> silu/leaky relu block"""



    def __init__(

        self, in_channels, out_channels, ksize, stride, bias=True, act="leaky"

    ):

        super(BaseConv,self).__init__()

        # same padding

        # pad = (ksize - 1) // 2

        self.conv = nn.Conv2d(

            in_channels,

            out_channels,

            kernel_size=ksize,

            stride=stride,

            # padding=pad, #默认 same padding

            has_bias=bias

        )

        self.bn = nn.BatchNorm2d(out_channels)

        self.act = nn.LeakyReLU(0.1)



    def construct(self, x):

        return self.act(self.bn(self.conv(x)))



class DecoupleHead(nn.Cell):

    def __init__(

        self,

        num_classes,

        width=1.0,

        strides=[8, 16, 32],

        in_channels=[128, 256, 512],

        act="leaky",

    ):

        """

        Args:

            act (str): activation type of conv. Defalut value: "silu".

            depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False.

        """

        super(DecoupleHead,self).__init__()



        self.n_anchors = 3

        self.num_classes = num_classes



        self.cls_convs = nn.CellList()

        self.reg_convs = nn.CellList()



        self.cls_preds = nn.CellList()

        self.reg_preds = nn.CellList()

        self.obj_preds = nn.CellList()



        self.stems = nn.CellList()

        self.concat = P.Concat(1)

        for i in range(len(in_channels)):

            self.stems.append(

                BaseConv(

                    in_channels=int(in_channels[i] * width),

                    out_channels=int(256 * width),

                    ksize=1,

                    stride=1,

                    act=act,

                )

            )

            self.cls_convs.append(

                nn.SequentialCell(

                    *[

                        BaseConv(

                            in_channels=int(256 * width),

                            out_channels=int(256 * width),

                            ksize=3,

                            stride=1,

                            act=act,

                        ),

                        BaseConv(

                            in_channels=int(256 * width),

                            out_channels=int(256 * width),

                            ksize=3,

                            stride=1,

                            act=act,

                        ),

                    ]

                )

            )

            self.reg_convs.append(

                nn.SequentialCell(

                    *[

                        BaseConv(

                            in_channels=int(256 * width),

                            out_channels=int(256 * width),

                            ksize=3,

                            stride=1,

                            act=act,

                        ),

                        BaseConv(

                            in_channels=int(256 * width),

                            out_channels=int(256 * width),

                            ksize=3,

                            stride=1,

                            act=act,

                        ),

                    ]

                )

            )

            # class

            self.cls_preds.append(

                nn.Conv2d(

                    in_channels=int(256 * width),

                    out_channels=self.n_anchors * self.num_classes,

                    kernel_size=1,

                    stride=1,

                    padding=0,

                )

            )

            # x,y,w,h

            self.reg_preds.append(

                nn.Conv2d(

                    in_channels=int(256 * width),

                    out_channels=self.n_anchors * 4,

                    kernel_size=1,

                    stride=1,

                    padding=0,

                )

            )

            #confidence

            self.obj_preds.append(

                nn.Conv2d(

                    in_channels=int(256 * width),

                    out_channels=self.n_anchors * 1,

                    kernel_size=1,

                    stride=1,

                    padding=0,

                )

            )

        self.strides = strides



    def construct(self, xin):

        outputs = []

        for k, (cls_conv, reg_conv, x) in enumerate(

            zip(self.cls_convs, self.reg_convs, xin)

        ):

            x = self.stems[k](x)

            cls_x = x

            reg_x = x

            # print("ffffffffffffffffffffffffffffff")

            # print(reg_x)

            # print("jjjjjjjjjjjjjjjjjjjjjj")

            cls_feat = cls_conv(cls_x)

            cls_output = self.cls_preds[k](cls_feat)



            reg_feat = reg_conv(reg_x)

           

            reg_output = self.reg_preds[k](reg_feat)



            obj_output = self.obj_preds[k](reg_feat)



            output = self.concat((reg_output, obj_output, cls_output))

            outputs.append(output)



        return outputs

在yolo中调用的代码如下:

class YOLO(nn.Cell):

    def __init__(self, backbone, shape):

        super(YOLO, self).__init__()

        self.backbone = backbone

        self.config = ConfigYOLOV5()

        #neck , fpn pan

        self.conv1 = Conv(shape[5], shape[4], k=1, s=1)

        self.CSP5 = BottleneckCSP(shape[5], shape[4], n=1*shape[6], shortcut=False)

        self.conv2 = Conv(shape[4], shape[3], k=1, s=1)

        self.CSP6 = BottleneckCSP(shape[4], shape[3], n=1*shape[6], shortcut=False)

        self.conv3 = Conv(shape[3], shape[3], k=3, s=2)

        self.CSP7 = BottleneckCSP(shape[4], shape[4], n=1*shape[6], shortcut=False)

        self.conv4 = Conv(shape[4], shape[4], k=3, s=2)

        self.CSP8 = BottleneckCSP(shape[5], shape[5], n=1*shape[6], shortcut=False)

        # yolovx head



        input_channels = [shape[3],shape[4],shape[5]]

        self.decouple_head = DecoupleHead(num_classes= self.config.num_classes, in_channels=input_channels)



        self.concat = P.Concat(axis=1)



    def construct(self, x):

        """

        input_shape of x is (batch_size, 3, h, w)

        feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)

        feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)

        feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)

        """

        img_height = P.Shape()(x)[2] * 2

        img_width = P.Shape()(x)[3] * 2



        feature_map1, feature_map2, feature_map3 = self.backbone(x)



        c1 = self.conv1(feature_map3)

        ups1 = P.ResizeNearestNeighbor((img_height // 16, img_width // 16))(c1)

        c2 = self.concat((ups1, feature_map2))

        c3 = self.CSP5(c2)

        c4 = self.conv2(c3)

        ups2 = P.ResizeNearestNeighbor((img_height // 8, img_width // 8))(c4)

        c5 = self.concat((ups2, feature_map1))

        # out

        c6 = self.CSP6(c5)

        c7 = self.conv3(c6)



        c8 = self.concat((c7, c4))

        # out

        c9 = self.CSP7(c8)

        c10 = self.conv4(c9)

        c11 = self.concat((c10, c1))

        # out

        c12 = self.CSP8(c11)

        # print("dddddddddddddd")

        # print(c6)

        small_object_output, medium_object_output, big_object_output=self.decouple_head([c6,c9,c12])



        return small_object_output, medium_object_output, big_object_output

使用的是静态图方式,在使用yolov5自带的head时运行没有问题,替换为这个decouplehead就出错了,结果下图中的错误,现在不知道如何定位,麻烦帮忙看一下

看报错信息是说少了一个参数,但是实在找不到在哪里出现的这个问题

【截图信息】

解答:

如果动态图可以,静态图不行的话,我怀疑

 

这块可能静态图的支持不够好,可以尝试用

for k in range(len(self.reg)):

    cls_conv=self.cls[k]

    reg_conv = convs[k]

    x = xin[k]

这种形式写

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值