How to effectively fuse RGB and Depth images?
One option is to run each of the two streams through its own convolutions and then fuse them, but the results tend to be mediocre; most current methods fuse via attention instead. I previously tried fusing with ASPP, since ASPP captures multi-scale context, but after plugging it in the results were dreadful, so I am recording the attempt here.
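For comparison, below is a minimal sketch of the attention-style fusion mentioned above, in the spirit of squeeze-and-excitation channel reweighting. The module name SEFusion and its exact structure are my own illustration, not the module from any particular paper:

import torch
import torch.nn as nn


# illustrative SE-style fusion: each modality is reweighted by channel
# attention computed from its own global statistics, then the two are added
class SEFusion(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()

        def se_branch():
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(1),  # squeeze: global average pooling
                nn.Conv2d(channels, channels // reduction, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels // reduction, channels, 1),
                nn.Sigmoid())  # excitation: per-channel weights in (0, 1)

        self.att_rgb = se_branch()
        self.att_depth = se_branch()

    def forward(self, rgb, depth):
        # reweight each stream by its own channel attention, then add
        return rgb * self.att_rgb(rgb) + depth * self.att_depth(depth)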
# -*- coding: utf-8 -*-
"""
.. codeauthor:: Mona Koehler <mona.koehler@tu-ilmenau.de>
.. codeauthor:: Daniel Seichter <daniel.seichter@tu-ilmenau.de>
"""
import torch
import torch.nn as nn
from torch.nn import functional as F


# 3x3 atrous (dilated) convolution -> BN -> ReLU
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)
# global pooling -> 1x1 conv -> upsample back to the input size
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),  # adaptive average pooling to 1x1
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        # upsample back to the spatial size of the input
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
# the full ASPP module
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels):
        super(ASPP, self).__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))
        # multi-scale atrous convolution branches
        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))
        # image pooling branch
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)
        # In the standard ASPP, a 1x1 projection follows the concat;
        # here it is deferred to the fusion module below.
        # self.project = nn.Sequential(
        #     nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
        #     nn.BatchNorm2d(out_channels),
        #     nn.ReLU(),
        #     nn.Dropout(0.5))

    def forward(self, x):
        # return the list of branch outputs; concat and projection
        # happen in multipathfusion
        res = []
        for conv in self.convs:
            res.append(conv(x))
        # res = torch.cat(res, dim=1)
        # y = self.project(res)
        return res
class multipathfusion(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels):
        super(multipathfusion, self).__init__()
        self.se_rgb = ASPP(in_channels, atrous_rates, out_channels)
        self.se_depth = ASPP(in_channels, atrous_rates, out_channels)
        # 5 branches per ASPP (assumes three atrous rates):
        # 1x1 conv + three atrous convs + image pooling
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, rgb, depth):
        res_rgb = self.se_rgb(rgb)
        res_depth = self.se_depth(depth)
        # sum the RGB and depth outputs branch by branch
        _res = [i + j for i, j in zip(res_rgb, res_depth)]
        # concat the fused branches, then project back to out_channels
        res = torch.cat(_res, dim=1)
        y = self.project(res)
        return y
def main():
    in_channels = 64
    atrous_rates = [6, 12, 18]
    out_channels = 64
    model = multipathfusion(in_channels, atrous_rates, out_channels)
    print(model)
    model.eval()
    rgb_image = torch.randn(1, 64, 320, 240)
    depth_image = torch.randn(1, 64, 320, 240)
    with torch.no_grad():
        out = model(rgb_image, depth_image)
    print(out.size())  # torch.Size([1, 64, 320, 240])


if __name__ == '__main__':
    main()
To try it out, add the modified ASPP fusion to the model, change the module name in args, and run train.py.
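A minimal sketch of the args change, assuming train.py exposes a fuse_depth_in_rgb_encoder option; the flag spelling and the pre-existing choices are assumptions inferred from the check in the snippet below:

# hypothetical args.py excerpt; flag and choice names are assumptions
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fuse-depth-in-rgb-encoder',
                    choices=['SE-add', 'MPF'],  # 'MPF' selects multipathfusion
                    default='SE-add')

The fusion layers are then instantiated in the model, one per encoder stage: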
if fuse_depth_in_rgb_encoder == 'MPF':
    self.se_layer0 = multipathfusion(64, [6, 12, 18], 64)
    self.se_layer1 = multipathfusion(
        self.encoder_rgb.down_4_channels_out, [6, 12, 18], 64)
    self.se_layer2 = multipathfusion(
        self.encoder_rgb.down_8_channels_out, [6, 12, 18], 128)
    self.se_layer3 = multipathfusion(
        self.encoder_rgb.down_16_channels_out, [6, 12, 18], 256)
    self.se_layer4 = multipathfusion(
        self.encoder_rgb.down_32_channels_out, [6, 12, 18], 512)
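For context, these layers would then be called once per encoder stage in the forward pass. The sketch below is purely illustrative; the stage hooks (forward_first_conv, forward_layer1, ...) are hypothetical names, not the actual encoder API:

# illustrative forward-pass excerpt; stage method names are hypothetical
rgb = self.encoder_rgb.forward_first_conv(rgb)
depth = self.encoder_depth.forward_first_conv(depth)
fuse = self.se_layer0(rgb, depth)  # fused features feed the RGB stream
rgb = self.encoder_rgb.forward_layer1(fuse)
depth = self.encoder_depth.forward_layer1(depth)
fuse = self.se_layer1(rgb, depth)
# ... repeat through se_layer4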
Result: the performance was surprisingly bad.