最近一直在做可见光图像与红外光图像的配准和融合研究,阅读了很多文献后发现在这一领域有很多很不错的工作都没有公开代码,因此希望写一整个系列来分享一些比较有趣的研究,并尝试复现这些论文。
这篇推文要复现的论文是Huafeng Li等人的《Feature dynamic alignment and refinement for infrared-visible image fusion: Translation robust fusion》,作者在文章中给出了GitHub链接,但是距文章发布已经过去了很久了,仓库里还是空空如也,最近的工作中需要用到这篇论文的部分架构,所以先来尝试复现一下。如果后期更新了仓库,有任何错误会及时修改。
由于篇幅限制,本篇暂时只复现论文中的Cross-Modulation Feature Fxtraction Module (CMFEM)交叉调制模块,其余的模块将在之后的推文中复现。
目录
文章介绍
- 标题:《Feature dynamic alignment and refinement for infrared-visible image fusion: Translation robust fusion》
- 单位:香港科技大学信息工程与自动化学院、哈尔滨工业大学(深圳)、深圳视觉目标检测与识别重点实验室
- 原文链接:https://www.sciencedirect.com/science/article/pii/S1566253523000519
- GitHub链接:https://github.com/lhf12278/RFVIF
这篇论文的工作个人认为很有意义,作者将VI和IR图片配准的过程集成进了网络当中,使得整个网络可以在具有轻微位移的尚未配准的VI和IR图像上进行融合,最终输出进行配准的融合结果。虽然只能实现在轻微位移情况下的融合,但这样的设计思想更适合实际的部署,是一项值得继续跟进的研究。
模型架构
完整模型架构
以上是模型的整体架构图,主要包括:Cross-Modulation Feature Extraction Module (CMFEM) 交叉调制模块, Feature Dynamic Alignment Module (FDAM) 动态对齐模块, Multi-Grained Feature Refinement Module (MGFRM) 多粒度细化模块,Pyramid Feature Fusion Module (PFFM) 金字塔融合模块。
交叉调制模块CMFEM架构
本篇推文主要对交叉调制模块CMFEM进行复现,以下是该模块的架构图
交叉调制模块CMFEM主要有两个核心机制:多尺度残差块MSRB(左)和交叉调制机制(右)。其中,多尺度残差块MSRB的作用是分别通过3x3和5x5卷积核来提取不同尺度的特征并进行聚合,之后使用跳跃连接将原始特征和聚合特征进行相加,在避免模型退化的同时学习到多尺度的特征。交叉调制机制的作用是通过提取聚合后的多尺度特征,并进行交叉地特征分离,之后再通过点积和逐像素相加的方式进一步将输入的图像同源化。
事实上,整个CMFEM模块的本质就是通过特征提取的方式,将异源的红外和可见光图像进行同源化的过程,并且最终得到的高级特征也有利于之后的对齐操作。然而,由于原论文中并没有给出具体细节,因此主要依据以上思想进行实现。
模型实现
核心库导入
由于本篇推文只对架构进行复现,不涉及训练和测试等工作,因此只需要导入定义架构时需要用到的核心代码。
import torch
import torch.nn as nn
from typing import List, Tuple, Union
多尺度残差块MSRB的实现
MSRB模块(Multi-scale Residual Block),即多尺度残差模块,该模块分别通过一个3x3卷积核和5x5卷积核来提取不同尺度的特征并拼接,之后再使用残差结构进行逐像素相加。原论文中并没有给出具体的通道数,因此基于模块的设计思想,对整个模块进行以下设置:
- 整个MSRB模块不改变特征通道数,与输入时的通道数保持一致;
- MSRB模块3x3和5x5卷积核不进行特征图尺寸的变换,仍然保持输入时的尺寸;
- 3x3和5x5卷积核得到的特征进行拼接后,通过1x1卷积核通道缩减,使其与跳跃连接的通道数相等。
以上三点设置主要考虑到:如果在MSRB模块中进行特征图尺寸的缩进,最终会导致特征图过小;如果在MSRB模块中进行通道数的扩展,会导致通道数过多。
MSRB模块实现的具体代码如下:
class MSRB(nn.Module):
def __init__(self, num_ch):
super().__init__()
self.in_ = num_ch
# 保持特征图尺寸不变
self.res_3 = nn.Sequential(nn.Conv2d(num_ch, num_ch, 3, 1, 1)
,nn.BatchNorm2d(num_ch)
,nn.ReLU(True))
self.res_5 = nn.Sequential(nn.Conv2d(num_ch, num_ch, 5, 1, 2)
,nn.BatchNorm2d(num_ch)
,nn.ReLU(True))
self.skip_ch = nn.Conv2d(num_ch, num_ch, 1, 1)
self.fea_ch = nn.Conv2d(2*num_ch, num_ch, 1, 1)
self.relu = nn.ReLU()
def forward(self, x):
x_1 = x # 跳跃连接
x_2 = self.res_3(x) # 3x3的卷积块
x_3 = self.res_5(x) # 5x5的卷积块
x_cat = torch.cat((x_2, x_3), dim=1) # 对不同尺度的特征进行拼接
x_muti = self.fea_ch(x_cat) # 拼接后特征的通道数减半
output = self.relu(x_1 + x_muti) # 逐像素相加后进行激活
return output
对MSRB模块进行测试,测试代码如下:
msrb = MSRB(64) ### 输入通道数与输出通道数一致
data = torch.rand(10, 64, 224, 224)
msrb(data).shape
# 测试结果:torch.Size([10, 64, 224, 224])
注意:在原论文中图像数据是直接输入到MSRB模块中的,但为了使代码更加优雅,输入的图像数据通过1x1卷积核进行通道拓展后再输入到MSRB模块中,这部分内容将在之后进行介绍。
交叉调制机制的实现
Cross Modulation机制,即交叉调制机制,该机制通过拼接输入特征再交叉地进行特征提取,并之前得到的特征进行点击后逐像素相加。基于模型的设计思想,对整个机制进行以下设置:
- Sconv层中进行了特征图尺寸减半操作;
- 所有的Convs层都不改变特征图尺寸;
- 整个Cross Modulation机制最终会使输入通道数翻倍;
以上三点设置主要是为了在交叉调制机制中进行特征图尺寸缩减和通道数翻倍,由于交叉调制机制的数量较少,因此不会存在特征图过小和通道数过多的情况。
交叉调制机制实现的具体代码如下:
class Cross_Mod(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.in_ = in_ch
self.out_ = out_ch
self.sconv = nn.Conv2d(in_ch, out_ch, 3, 2) # 进行特征图尺寸减半
self.conv_cat = nn.Conv2d(2*out_ch, 2*out_ch, 3, 1, 1) # 用于拼接后的卷积核
self.convs = nn.Sequential(nn.Conv2d(out_ch, out_ch, 3, 1, 1)
,nn.Sigmoid())
def forward(self, ir, vi):
ir = self.sconv(ir)
vi = self.sconv(vi)
img_cat = torch.cat((vi, ir), dim=1)
img_conv = self.conv_cat(img_cat)
# 平均拆分拼接后的特征
split_ir = img_conv[:,self.out_:] # 前一半特征
split_vi = img_conv[:,:self.out_] # 后一半特征
# IR图像处理流程
ir_1 = self.convs(split_ir)
ir_2 = self.convs(split_ir)
ir_mul = torch.mul(ir, ir_1)
ir_add = ir_1 + ir_2
# VI图像处理流程
vi_1 = self.convs(split_vi)
vi_2 = self.convs(split_vi)
vi_mul = torch.mul(vi, vi_1)
vi_add = vi_1 + vi_2
return ir_add,vi_add
对交叉调制机制进行测试,测试代码如下:
cross_mod = Cross_Mod(64, 128) # 输出通道数翻倍
data = torch.rand(10, 64, 224, 224)
cross_mod(data, data)[0].shape # 返回IR和VI
# 测试结果:torch.Size([10, 128, 111, 111])
串联模块的实现
为了可以按照原文结构将MSRB模块和交叉调制机制作为一个整体,这里定义一个串联模块,该模块可以设定MSRB模块的数量,并返回经过MSRB模块和交叉调制机制后的VI和IR图像的结果。
class MSRB_Cross(nn.Module):
def __init__(self, num, in_ch, out_ch):
'''
num:MSRB的数量
in_ch:MSRB的输入、输出和交叉调制机制的输入通道数
out_ch:交叉调制机制的输出通道数
'''
super().__init__()
self.block = nn.Sequential() # 定义一个空的Sequential存放MSRB模块
for i in range(num):
# 根据设定的数量将MSRB模块加入到block中
self.block.add_module(f'MSRB_{i+1}', MSRB(in_ch))
self.cross = Cross_Mod(in_ch, out_ch) # 定义交叉调制机制
def forward(self, ir, vi):
ir_msrb = self.block(ir)
vi_msrb = self.block(vi)
ir_output,vi_output = self.cross(ir_msrb, vi_msrb)
return ir_output,vi_output
对串联模块进行测试,测试代码如下:
msrb_cross = MSRB_Cross(3, 64, 128) # 定义了3个MSRB模块
data = torch.rand(10, 64, 224, 224)
msrb_cross(data, data)[0].shape # 返回IR和VI
# 测试结果:torch.Size([10, 128, 111, 111])
交叉调制模块CMFEM
交叉调制模块(Cross-Modulation Feature Extraction Module),即交叉调制模块,是根据原文结构对MSRB模块和交叉调制机制的整体定义,根据模块的设计思想,对整个模块进行以下设置:
- MSRB模块保持通道数不变,使用Cross_Mod模块扩展通道数;
- 由于原论文架构中没有在输入图像与架构之间添加通道转换的操作,因此在CMFEM类中先对输入图像进行通道转换,使其通道数转化为第一层的输入通道数。
在设置中最核心的改变是在输入图像与交叉调制模块之间设置了1x1卷积层来扩展通道数,使得输入的3维图像可以转化为给定的输入特征,以此方便整个模块的定义。
交叉调制模块CMFEM实现的具体代码如下:
class CMFEM(nn.Module):
def __init__(self, params: List[Union[Tuple[int, int, int]]]):
'''
params:输入列表,列表中存放多个元组,即List[Union[Tuple[int, int, int]]],
通过List中元组的个数来表示共有多少个MSRB_Cross层;
Tuple[int, int, int]:第1个int表示该层中MSRB模块的数量,第2个int表示每个
MSRB的输入(输出)通道数,第3个int表示每个Cross_Mod的输出通道数;
输出通道数:假设输入的参数为[(3, 64, 128)],则MSRB的输入通道为64,输出通道
也为64,Cross_Mod的输入通道数为64,输出通道数为128
'''
super().__init__()
self.num_layers = len(params)
self.img_ch = nn.Conv2d(3, params[0][1], 1, 1)
self.cmfem = nn.Sequential()
for i in range(len(params)):
self.cmfem.add_module(f'layers_{i+1}', MSRB_Cross(params[i][0] # MSRB数量
,params[i][1] # in_ch
,params[i][2])) # out_ch
def forward(self, ir, vi):
ir_output = self.img_ch(ir) # 对原始IR图像进行通道转换
vi_output = self.img_ch(vi) # 对原始VI图像进行通道转换
for i in range(self.num_layers):
# 需要循环传入参数,否则会报错
ir_output, vi_output = self.cmfem[i](ir_output, vi_output)
return ir_output,vi_output
对交叉调制模块CMFEM进行测试,测试代码如下:
params = [(3,64,128), (2,128,256), (2,256,512)] # 原论文架构
cmfem = CMFEM(params)
data = torch.rand(10, 3, 224, 224)
cmfem(data,data)[0].shape # 返回IR和VI
# 测试结果:torch.Size([10, 512, 27, 27])
以上就是交叉调制模块CMFEM的定义,在多次测试后可以保证代码能够正常运行。之后会继续进行这篇论文的其他架构的复现,也希望作者大大可以早日将源码上传到GitHub,有任何错误和问题以及想法都可以在评论区交流。
希望这个系列可以一直更新下去!!!😊😊😊