自己融合RGB和Depth图像

如何有效地融合RGB和Depth图像?
一种做法是先分别对RGB和Depth两个分支做卷积,然后再进行融合,但效果可能不太好;目前大部分工作都采用基于注意力的融合方式。我之前想用ASPP来做融合,因为ASPP可以捕捉上下文信息,但实际应用之后结果惨不忍睹,特此记录一下。

# -*- coding: utf-8 -*-
"""
.. codeauthor:: Mona Koehler <mona.koehler@tu-ilmenau.de>
.. codeauthor:: Daniel Seichter <daniel.seichter@tu-ilmenau.de>
"""
import torch
import torch.nn as nn
from torch.nn import functional as F

class ASPPConv(nn.Sequential):
    """Atrous 3x3 convolution branch: Conv2d -> BatchNorm2d -> ReLU.

    The padding is set equal to the dilation, so the stride-1 3x3
    convolution preserves the spatial size of its input.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the branch.
        dilation: Dilation rate of the convolution (also used as padding).
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,  # the following BatchNorm supplies the affine shift
        )
        super().__init__(atrous_conv, nn.BatchNorm2d(out_channels), nn.ReLU())


# Pooling -> 1x1 convolution -> upsampling
class ASPPPooling(nn.Sequential):
    """Image-level branch: global average pool -> 1x1 conv -> BN -> ReLU.

    The pooled 1x1 result is bilinearly upsampled back to the spatial size
    of the input, so it can be combined with the other ASPP branches.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the branch.
    """

    def __init__(self, in_channels, out_channels):
        global_pool = nn.AdaptiveAvgPool2d(1)  # collapse H x W to 1 x 1
        pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        super().__init__(
            global_pool,
            pointwise,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the pooled features resized to the input's last two dims."""
        target_size = x.shape[-2:]
        # nn.Sequential.forward runs the registered layers in order.
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=target_size, mode='bilinear', align_corners=False
        )

# Complete ASPP module


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling that returns its branches separately.

    The module holds, in this order: one 1x1 convolution branch, one
    ``ASPPConv`` branch per entry of ``atrous_rates``, and one
    ``ASPPPooling`` branch. No concatenation or projection is applied
    here; ``forward`` hands back the raw per-branch outputs so that the
    caller can combine them.

    Args:
        in_channels: Number of channels of the input feature map.
        atrous_rates: Iterable of dilation rates, one atrous branch each.
        out_channels: Number of channels produced by every branch.
    """

    def __init__(self, in_channels, atrous_rates, out_channels):
        super().__init__()

        # 1x1 convolution branch
        pointwise_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        # Multi-scale atrous convolution branches
        atrous_branches = [
            ASPPConv(in_channels, out_channels, rate)
            for rate in tuple(atrous_rates)
        ]
        # Image-level pooling branch
        pooling_branch = ASPPPooling(in_channels, out_channels)

        self.convs = nn.ModuleList(
            [pointwise_branch, *atrous_branches, pooling_branch]
        )

    def forward(self, x):
        """Apply every branch to ``x`` and return the outputs as a list."""
        return [branch(x) for branch in self.convs]
class multipathfusion(nn.Module):
    """Fuse RGB and depth feature maps through two parallel ASPP modules.

    Each modality is passed through its own ``ASPP``. Corresponding branch
    outputs are added element-wise, the sums are concatenated along the
    channel dimension and reduced to ``out_channels`` by a
    1x1 conv -> BatchNorm -> ReLU -> Dropout(0.5) projection.

    Args:
        in_channels: Number of channels of both the RGB and depth inputs.
        atrous_rates: Iterable of dilation rates for the ASPP branches.
            Any length is supported.
        out_channels: Number of channels of the fused output.
    """

    def __init__(self, in_channels, atrous_rates, out_channels):
        super(multipathfusion, self).__init__()

        # Materialize once so both ASPP modules get the same rates even when
        # the caller passes a one-shot iterator.
        rates = tuple(atrous_rates)

        self.se_rgb = ASPP(in_channels, rates, out_channels)
        self.se_depth = ASPP(in_channels, rates, out_channels)

        # Size the projection from the actual number of ASPP branches
        # instead of assuming exactly three atrous rates (5 branches).
        num_branches = len(self.se_rgb.convs)
        self.project = nn.Sequential(
            nn.Conv2d(num_branches * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, rgb, depth):
        """Return the fused feature map for ``rgb`` and ``depth``.

        The element-wise addition requires the per-branch outputs of both
        modalities to have matching (or broadcastable) shapes.
        """
        rgb_branches = self.se_rgb(rgb)
        depth_branches = self.se_depth(depth)
        # Branch-wise sum of the two modalities.
        fused = [r + d for r, d in zip(rgb_branches, depth_branches)]
        return self.project(torch.cat(fused, dim=1))


# def main():
#     in_channels = 64
#     atrous_rates = [6,12,18]
#     out_channels = 64

#     model = multipathfusion(
#         in_channels, atrous_rates, out_channels)
#     print("------------------------------------------------------------------------------")
#     print(model)

#     model.eval()
#     rgb_image = torch.randn(1, 64, 320, 240)
#     depth_image = torch.randn(1, 64, 320, 240)

#     print("------------------------------------------------------------------------------")
#     # print(rgb_image)

#     with torch.no_grad():
#         out = model(rgb_image,depth_image)


#     print("------------------------------------------------------------------------------")
#     print(out)
#     print(out.size())
#     # print(res_depth)

# # torch.Size([1, 64, 320, 240])


# if __name__ == '__main__':
#     main()

在模型中加入修改后的ASPP融合模块(multipathfusion),然后在args中把融合模块的名称改为'MPF',再运行train.py。

# Excerpt (surrounding method not shown; presumably the model's __init__ —
# TODO confirm): when the fusion option is 'MPF', create five
# multipathfusion modules, all with dilation rates [6, 12, 18].
# NOTE(review): the output channel counts below are hard-coded; verify they
# match what the rest of the model expects at each stage.
if fuse_depth_in_rgb_encoder == 'MPF':
            # Layer 0: fixed 64 input channels -> 64 output channels.
            self.se_layer0 = multipathfusion(
                64,[6,12,18],64)
            # Layer 1: input channels from encoder_rgb.down_4_channels_out -> 64.
            self.se_layer1 = multipathfusion(
                self.encoder_rgb.down_4_channels_out,
                [6,12,18],64)
            # Layer 2: input channels from encoder_rgb.down_8_channels_out -> 128.
            self.se_layer2 = multipathfusion(
                self.encoder_rgb.down_8_channels_out,
                [6,12,18],128)
            # Layer 3: input channels from encoder_rgb.down_16_channels_out -> 256.
            self.se_layer3 = multipathfusion(
                self.encoder_rgb.down_16_channels_out,
                [6,12,18],256)
            # Layer 4: input channels from encoder_rgb.down_32_channels_out -> 512.
            self.se_layer4 = multipathfusion(
                self.encoder_rgb.down_32_channels_out,
                [6,12,18],512)

结果:效果出奇的差。(原文此处为一张图片,未能保留)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值