RESA 车道线检测模型-debug分析

车道线检测模型 RESA

该模型只有一个关键点就是resa模块,把这个想清楚就没什么了,下面看代码

class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        # self.iter = cfg.resa.iter
        # chan = cfg.resa.input_channel
        # fea_stride = cfg.backbone.fea_stride
        # self.height = cfg.img_height // fea_stride
        # self.width = cfg.img_width // fea_stride
        # self.alpha = cfg.resa.alpha
        # conv_stride = cfg.resa.conv_stride
        self.iter = 5 #5
        chan = 64  #128
        fea_stride = 4  #8
        self.height = 96
        self.width =160
        # print("self.width",self.width)
        # print("self.height",self.height)
        self.alpha = 2.0    #2
        conv_stride = 9   #9

        for i in range(self.iter):
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        print('------------------',x.shape)
        print(x.shape)
        x = x.clone()

        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x

上述代码中的一些超参数,是我自己设置的,便于看,免得看config了,这个的关键就是如何x.add_是怎么加的,这里面用到了一些索引,我们具体来dubug看一下
循环iter,
iter=0时
在这里插入图片描述
iter=1时
在这里插入图片描述

iter=2时
在这里插入图片描述

iter=3时
在这里插入图片描述
iter=4时
在这里插入图片描述
看到这里大家应该就明白了吧,主要实现错位的相加,依照这个顺序执行的啊,这样就实现了文中说的消息的传递,比CNN好

  • 9
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值