MsTGANet: Automatic Drusen Segmentation From Retinal OCT Images(从视网膜OCT图像自动分割玻璃疣病变)
目录
一、摘要
研究背景:Drusen(玻璃疣)被认为是AMD(黄斑变性)诊断的标志,也是 AMD发展的重要危险因素。因此,准确分割视网膜OCT图像中的玻璃疣对AMD的早期诊断至关重要。
研究现状:1. 在视网膜OCT图像中,由于玻璃疣的大小和形状变化很大,边界模糊,以及斑点噪声干扰,玻璃疣分割仍然是非常具有挑战性。2. 缺乏具有像素级标注的OCT数据集。
主要工作:提出了一种新的多尺度Transformer全局注意力网络(MsTGANet)用于视网膜OCT图像中的玻璃疣分割。
- 1. 在基于U形结构的MsTGANet中,设计了一种新的多尺度Transformer非局部(MsTNL)模块,并将其插入到编码器路径的顶部(插入位置),旨在捕获编码器不同层具有远程依赖性的多尺度非局部特征(子模块1的设计目的)。
- 2. 同时,在编码器和解码器之间(位置)提出了一种新的多语义全局通道和空间联合注意模块(MsGCS),引导模型融合不同的语义特征(子模块2的设计目的),从而提高模型学习多语义全局上下文信息的能力。
半监督策略:为了缓解标记数据不足的问题,本文提出了一种基于伪标记数据增强策略的半监督MsTGANet(Semi-MsTGANet),它可以利用大量的未标记数据进一步提高分割性能。
研究成果:实验结果表明,我们提出的方法取得了更好的分割精度比其他最先进的基于CNN的方法。
二、创新点
1) 提出了一种新的MsTNL模块,并将其嵌入到编码器路径的顶部,以捕获编码器中不同层的具有长程依赖性的多尺度非局部特征。
2) 为了提高模型学习多语义全局上下文特征的能力,提出了一种新的MsGCS模块,并插入编码器和解码器之间。
3) 结合MsTNL和MsGCS模块,设计了一种基于U形结构的MsTGANet网络,并将其应用于OCT图像中的玻璃疣分割。
4) 进一步提出了一种新的半监督版本的MsTGANet(半监督MsTGANet)的基础上伪标记的数据增强策略,它可以利用大量的未标记的数据,以进一步提高分割精度。
三、MsTGANet的具体现实
MsTGANet的结构图如下:
3.1 MsTGANet网络结构概述
主要由Encoder路径、MsTNL模块、MsGCS模块和Dncoder路径四部分组成。
MsTNL模块作用:将来自不同Encoder(编码器)层的具有不同尺度信息的特征图作为MsTNL的输入,捕获具有远程依赖关系的多尺度非局部特征。
MsGCS模块作用:取代(unet中的)跳连接,引导模型融合多语义全局上下文特征,从而提高模型学习全局显著特征的能力,同时抑制不相关局部特征的干扰。
3.2 Encoder和Dncoder
与unet的编码器与解码器一致。
作用:特征提取和加强特征。
结构:编码器主要包含5个块,除了第一个块只有两个卷积层,其他每个块包含一个MaxPool操作,然后是两个卷积层。解码器路径包含5个块,每个块主要包含一个上采样层、特征融合操作和两个卷积层。
3.3 MsTNL模块
MsTNL模块由多头自注意编码器和多头自注意解码器两部分组成。
目的:在OCT图像中,玻璃疣的病理表现非常复杂,尤其是在大小和形状上,而且还伴随着许多其他病理和噪声干扰,因此提高网络学习多尺度非局部特征的能力对于提高玻璃疣分割的精度至关重要。
Q:在充斥其他病理和噪声干扰的OCT图像,为什么多尺度非局部特征能够提高玻璃疣分割的精度?
A:首先,非局部操作可以提取图像像素间的远程依赖关系,简单来说,非局部操作可以使相似的像素值被整合为一类,最终得到的新的特征图的像素值是按类别划分的。这样就可以将玻璃疣与其他病理特征分割开来。
其次,多尺度,输入的特征图来源于多个不同的尺度大小的特征图像,保留了更多的特征细节。利于提高分割的精度。
(1) 多头自注意编码器
目的:在 引导下,提取 中具有长依赖关系的多尺度非局部特征。
输入:Encoder路径上不同的阶段的特征图 和顶层图像 (原图)。
过程:
都通过一个maxpooling和一个3x3卷积,然后通过逐元素相加获得 ,与的大小和通道数一致。
Q:为什么要设置下采样?
A:为了保证输入图像的大小一致。
1. 获取Q,K,V
通过1x1的卷积得到query值Q, 分别通过两个1x1的卷积得到key值、value值K、V,定义如下:
2. 进行位置特征编码
为了帮助网络聚焦位置特征间的远程依赖关系,模块内分别在垂直方向和水平方向设置了两个学习向量、以促进对位置特征的学习。,通过reshape操作再逐元素相加得到特征PE。有关position encode的定义如下:
encode是指参数向量学习的过程。
3. 获取注意力图att
首先,通过矩阵乘法计算Q与K之间的相似度矩阵E,获得多尺度非局部空间相关性权重。然后,PE和Q进行矩阵乘法获取Q中特征的垂直和水平方向的位置相关性权重矩阵EP。最后,通过将E和EP相加,再通过一个Softmax函数获得注意力图Att。相关定义如下:
Q:位置相关性权重矩阵EP的作用大吗?
其中,◦ 是矩阵乘法运算。
4. 获取非局部空间响应
对注意力图Att和对应的value值V进行矩阵乘法,以获得具有强全局语义的多尺度非局部空间特征 。(经典的多头注意力机制特征融合操作)
non-local spatial response非局部空间响应,其中,response应该是代指特征的意思。
3. 获取多尺度非局部特征图
最后,通过 和加权的 进行逐元素相加获得具有远程依赖性的多尺度非局部特征图 ,定义如下所示:
其中,γ是初始化为0的可学习参数,并且在训练过程中逐渐调整以按可学习方式分配 的权重。
总结:该多头自注意编码器模块很大程度上参考NLNet,是非局部操作的延伸,继承并超越NLNet,但很明显也继承了NLNet计算消耗量大的缺点,不知道是否也有查询无关的问题。
代码如下:
class SPP_Q(nn.Module):
def __init__(self,in_ch,out_ch,down_scale,ks=3):
super(SPP_Q, self).__init__()
self.Conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=ks, stride=1, padding=ks // 2,bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
self.Down = nn.Upsample(scale_factor=down_scale,mode="bilinear")
def forward(self, x):
x_d = self.Down(x)
x_out = self.Conv(x_d)
return x_out
class Encoder_Pos(nn.Module):
# 512
def __init__(self, n_dims, width=32, height=32, filters=[32,64,128,256]):
super(Encoder_Pos, self).__init__()
print("================= Multi_Head_Encoder =================")
# 512
self.chanel_in = n_dims
# 学习向量
self.rel_h = nn.Parameter(torch.randn([1, n_dims//8, height, 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, n_dims//8, 1, width]), requires_grad=True)
# 多尺度输入
self.SPP_Q_0 = SPP_Q(in_ch=filters[0],out_ch=n_dims,down_scale=1/16,ks=3)
self.SPP_Q_1 = SPP_Q(in_ch=filters[1],out_ch=n_dims,down_scale=1/8,ks=3)
self.SPP_Q_2 = SPP_Q(in_ch=filters[2],out_ch=n_dims,down_scale=1/4,ks=3)
self.SPP_Q_3 = SPP_Q(in_ch=filters[3],out_ch=n_dims,down_scale=1/2,ks=3)
# q,k,v
self.query_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
# 特征FT,(512,32,32),
def forward(self, x,x_list):
m_batchsize, C, width, height = x.size()
# 特征FA,[512,32,32]
Multi_X = self.SPP_Q_0(x_list[0]) + self.SPP_Q_1(x_list[1]) + self.SPP_Q_2(x_list[2]) + self.SPP_Q_3(x_list[3])
# 特征Q,由FA通过卷积再整形而来,[2,1024,64]
proj_query = self.query_conv(Multi_X).view(m_batchsize, -1, width * height).permute(0, 2, 1)
# 特征K,[2,1024,64]
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
# 特征E
energy_content = torch.bmm(proj_query, proj_key)
# 位置特征编码
content_position = (self.rel_h + self.rel_w).view(1, self.chanel_in//8, -1)
content_position = torch.matmul(proj_query,content_position)
energy = energy_content + content_position
# att特征,[2, 1024, 1024]
attention = self.softmax(energy)
# 特征V,[2,512,1024]
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
return out, attention
(2) 多头自注意解码器
目的:基于多尺度非局部特征 的指导,进一步提取顶层特征图 中包含的强语义位置自相关信息。
输入:顶层特征图 和多尺度非局部特征 。
过程:
1. 获取,,
采用三个1x1的卷积作为Q,K,V三个分支,分别将 转化为,将 转化为,,定义如下:
2. 获取的位置编码
与多头自注意编码器中一样,分别设置两个学习向量、,分别从垂直和水平方向对中的强语义特征位置进行编码。再经过reshape操作和逐元素相加得到位置特征。公式定义如下:
3. 获取注意力图
与编码器中的步骤一致。最终得到捕获了空间相关性和位置相关性的注意力图 ,。定义如下:
其中,◦ 是矩阵乘法运算。
4. 获取非局部特征
注意图和进行矩阵乘法,再经过一个Reshape操作,获得具有强语义位置和远程依赖关系信息的空间特征。定义如下:
5. 获取多尺度强语义非局部特征图
最后,和进行残差求和,获得最终的具有长依赖性的多尺度强语义非局部特征图 ,定义如下所示:
其中,γ是初始化为0的可学习参数,并且在训练过程中逐渐调整以按可学习方式分配 的权重。
代码如下:
class Decoder_Pos(nn.Module):
def __init__(self, n_dims, width=32, height=32):
super(Decoder_Pos, self).__init__()
print("================= Multi_Head_Decoder =================")
self.chanel_in = n_dims
self.rel_h = nn.Parameter(torch.randn([1, n_dims//8, height, 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, n_dims//8, 1, width]), requires_grad=True)
self.query_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x,x_encoder):
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x_encoder).view(m_batchsize, -1, width * height)
energy_content = torch.bmm(proj_query, proj_key)
content_position = (self.rel_h + self.rel_w).view(1, self.chanel_in//8, -1)
content_position = torch.matmul(proj_query,content_position)
energy = energy_content+content_position
attention = self.softmax(energy)
proj_value = self.value_conv(x_encoder).view(m_batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
return out, attention
3.4 MsGCS模块
问题:编码器和解码器之间的简单跳连接忽略了全局信息,并且可能引入来自局部不相关特征的干扰。
解决方法:提出了一种新的多语义全局通道和空间联合注意力模块(MsGCS),以取代简单的跳连接。
目的:引导模型学习在通道和空间维度方向上的多语义全局上下文特征。
输入:(Encode路径)高分辨率弱语义特征的特征图 和(Decode路径)低分辨率强语义特征的上采样特征图 。
过程:首先,将 和 通过拼接操作进行融合,得到新的特征图通道数变为C1+C2,再通过一个卷积,将通道数压缩为1,以获得信道维度上的多语义全局特征图。定义如下:
然后,进一步自适应地捕获空间维度上的多语义全局特征,通过设置一个全局可学习权重矩阵 与 相乘,然后进行批量归一化操作(BN)和sigmoid激活,得到一个多语义全局特征注意力图 。定义如下:
其中, 和 是水平和垂直方向上的空间特征位置相关向量。
Q:为什么要设置学习权重矩阵,传统的空间注意力操作不应是空间注意力特征与原图像融合吗?
A:MsGCS模块通过压缩通道,得到一个空间注意力特征图,注意该空间注意力特征图已经是一个空间特征的权重矩阵,所以作者希望再设置一个学习权重矩阵来加强空间特征的学习(也可能能够提高训练速度)。
代码如下:
import torch
import torch.nn as nn
class MsGCS(nn.Module):
def __init__(self,F_g,F_l,F_int,size):
super(MsGCS, self).__init__()
print("=============== MsGCS ===============")
self.Conv = nn.Sequential(
nn.Conv2d(F_g+F_l,F_int,kernel_size=1),
nn.BatchNorm2d(F_int),
nn.ReLU(inplace=True),
nn.Conv2d(F_int,1,kernel_size=1),
nn.BatchNorm2d(1),
nn.ReLU(inplace=True)
)
self.rel_h = nn.Parameter(torch.randn([1, 1, size[0], 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, 1, 1, size[1]]), requires_grad=True)
self.bn = nn.BatchNorm2d(1)
self.active = nn.Sigmoid()
def forward(self,g,x):
x_Multi_Scale = self.Conv(torch.cat([g,x],dim=1))
content_position = self.rel_h+self.rel_w
x_att_multi_scale = self.active(self.bn(content_position*x_Multi_Scale))
return x_att_multi_scale*x
3.5 损失函数
采用了一个包含Dice损失和二进制交叉熵损失的联合损失函数 来指导模型的训练。联合损失函数定义如下:
其中,X和Y表示分割结果和对应的ground truth,h和w表示X和Y中的像素的坐标。
3.6 半监督MsTGANet框架
问题:缺乏具有像素级标注的OCT数据集。
解决方法:提出了一种基于伪标记数据增强策略的半监督MsTGANet框架,进一步提高了分割精度。
半监督MsTGANet框架基本步骤:
1) 首先在全监督目标函数的指导下,基于标记数据训练MsTGANet。
2) 采用上述预训练的MsTGANet对大量未标记数据中的玻璃疣进行分割,并将分割结果作为未标记数据对应的伪标签。
3) 将大量带有伪标签的未标记数据与标记数据混合,基于混合监督损失函数对MsTGANet进行再训练。
混合监督损失函数定义如下:
其中, 和 分别为原始数据和对应的标签, 和 分别表示原始未标记数据和预训练生成的伪标签。
半监督的训练方式,简单来说是运用预训练好的全监督网络对原始未标记数据生成伪标签,然后将带有伪标签的未标记数据与标记数据混合再进行训练。
四、实验和结果
数据集:UCSD数据集(8616张OCT B超扫描图像)。
评估指标:Jaccard指数(Jac)、Dice相似系数(DSC)、precision (Pre)和Pearson积矩相关系数(Ppmcc)。定义如下:
4.1 定性分析
半监督MsTGANet、MsTGANet与经典网络Unet、CE-Net对比,在存在散斑噪声干扰和大小或形状变化的情况下的7个分割结果,如下图所示。
如图所示,半监督MsTGANet实现了更好的分割性能。作者分析,在U-Net和CE-Net中,跳过连接的特征映射通过简单的拼接操作直接与上采样的特征映射合并,很难避免不相关的局部信息和散斑噪声的干扰,导致误报。
结论:本文提出的MsTGANet和半监督MsTGANet,在散斑噪声干扰和一些尺寸或形状变化的影响下仍能取得更好的分割性能,证明了我们提出 的方法的有效性和鲁棒性。
4.2 定量评价
评估标准:Jac, DSC, Pre和Ppmcc。
对比网络:FCN、U-Net、FastFCN、Attention UNet(Att-UNet)、PsPNet、DeepLabV3、CENet、DANet、CPFNet、GCN 、R2UNet、unet++和HRSegNet。
由表所示,作者提出的半监督MsTGANet在所有的评估指标中达到了最优。
对比网络劣势分析:
1. FCN实现了最差的结果,FCN由于是基于VGG的顶层的特征作为分割目标,这可能导致一些小尺寸玻璃疣的特征信息的丢失。
2. FastFCN理同FCN的劣势原因。
3. PsPNet和DeepLabV3都通过引入特征金字塔模块来捕获多尺度特征来提高分割性能, 这也证明了多尺度特征有利于提高分割性能。
4.3 消融实验
结果表明,所提出的MsGCS模块有利于提高模型的性能,所提出的MsTGANet(UNet+MsTNL+ MsGCS)可以显著提高OCT图像的图像分割精度。
五、结论
1. MsTNL模块是一个多头注意力机制 + NL操作,它是NL模块的延伸,但是不知道是否继承了NLNet查询位置无关的缺点?此外MsTNL模块内部多处使用了学习权重矩阵,该学习权重矩阵加强注意力特征程度如何不得而知。(MsTNL模块导致的训练缓慢,实际训练过程中可见)
2. MsGCS模块是一个空间注意力模块,同样引用了学习权重矩阵。
3. 半监督的训练方式,简单来说是运用预训练好的全监督网络对原始未标记数据生成伪标签,然后将带有伪标签的未标记数据与标记数据混合再进行训练。
六、代码实现
MsTGANet.py
import torch
import torch.nn as nn
from torchsummary import summary
from models.MsGCS import MsGCS as MsM
from models.MsTNL import MsTNL
class conv_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(conv_block, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class up_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(up_conv, self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.up(x)
return x
class MsTGANet(nn.Module):
def __init__(self, in_channels=3, num_classes=1, feature_scale=2):
super(MsTGANet, self).__init__()
print("================ MsTGANet ================")
filters = [64, 128, 256, 512, 1024]
filters = [int(x / feature_scale) for x in filters]
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Conv1 = conv_block(ch_in=in_channels, ch_out=filters[0])
self.Conv2 = conv_block(ch_in=filters[0], ch_out=filters[1])
self.Conv3 = conv_block(ch_in=filters[1], ch_out=filters[2])
self.Conv4 = conv_block(ch_in=filters[2], ch_out=filters[3])
self.Conv5 = conv_block(ch_in=filters[3], ch_out=filters[4])
self.trans = MsTNL(train_dim=512, filters=filters)
self.Up5 = up_conv(ch_in=filters[4], ch_out=filters[3])
self.Att5 = MsM(F_g=filters[3], F_l=filters[3], F_int=filters[2], size=(64, 64))
self.Up_conv5 = conv_block(ch_in=filters[4], ch_out=filters[3])
self.Up4 = up_conv(ch_in=filters[3], ch_out=filters[2])
self.Att4 = MsM(F_g=filters[2], F_l=filters[2], F_int=filters[1], size=(128, 128))
self.Up_conv4 = conv_block(ch_in=filters[3], ch_out=filters[2])
self.Up3 = up_conv(ch_in=filters[2], ch_out=filters[1])
self.Att3 = MsM(F_g=filters[1], F_l=filters[1], F_int=filters[0], size=(256, 256))
self.Up_conv3 = conv_block(ch_in=filters[2], ch_out=filters[1])
self.Up2 = up_conv(ch_in=filters[1], ch_out=filters[0])
self.Att2 = MsM(F_g=filters[0], F_l=filters[0], F_int=filters[0] // 2, size=(512, 512))
self.Up_conv2 = conv_block(ch_in=filters[1], ch_out=filters[0])
self.Conv_1x1 = nn.Conv2d(filters[0], num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x1 = self.Conv1(x)
x2 = self.Maxpool(x1)
x2 = self.Conv2(x2)
x3 = self.Maxpool(x2)
x3 = self.Conv3(x3)
x4 = self.Maxpool(x3)
x4 = self.Conv4(x4)
x5 = self.Maxpool(x4)
x5 = self.Conv5(x5)
x5 = self.trans(x5, [x1, x2, x3, x4])
d5 = self.Up5(x5)
x4 = self.Att5(g=d5, x=x4)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_conv5(d5)
d4 = self.Up4(d5)
x3 = self.Att4(g=d4, x=x3)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_conv4(d4)
d3 = self.Up3(d4)
x2 = self.Att3(g=d3, x=x2)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_conv3(d3)
d2 = self.Up2(d3)
x1 = self.Att2(g=d2, x=x1)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_conv2(d2)
d1 = self.Conv_1x1(d2)
return torch.sigmoid(d1)
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = MsTGANet().to(device)
# 打印网络结构和参数
summary(net, (3, 512, 512))
MsTNL.py
import torch
import torch.nn as nn
class SPP_Q(nn.Module):
def __init__(self,in_ch,out_ch,down_scale,ks=3):
super(SPP_Q, self).__init__()
self.Conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=ks, stride=1, padding=ks // 2,bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
self.Down = nn.Upsample(scale_factor=down_scale,mode="bilinear")
def forward(self, x):
x_d = self.Down(x)
x_out = self.Conv(x_d)
return x_out
class Encoder_Pos(nn.Module):
def __init__(self, n_dims, width=32, height=32, filters=[32,64,128,256]):
super(Encoder_Pos, self).__init__()
print("================= Multi_Head_Encoder =================")
self.chanel_in = n_dims
self.rel_h = nn.Parameter(torch.randn([1, n_dims//8, height, 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, n_dims//8, 1, width]), requires_grad=True)
self.SPP_Q_0 = SPP_Q(in_ch=filters[0],out_ch=n_dims,down_scale=1/16,ks=3)
self.SPP_Q_1 = SPP_Q(in_ch=filters[1],out_ch=n_dims,down_scale=1/8,ks=3)
self.SPP_Q_2 = SPP_Q(in_ch=filters[2],out_ch=n_dims,down_scale=1/4,ks=3)
self.SPP_Q_3 = SPP_Q(in_ch=filters[3],out_ch=n_dims,down_scale=1/2,ks=3)
self.query_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = n_dims , out_channels = n_dims , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x,x_list):
m_batchsize, C, width, height = x.size()
Multi_X = self.SPP_Q_0(x_list[0]) + self.SPP_Q_1(x_list[1]) + self.SPP_Q_2(x_list[2]) + self.SPP_Q_3(x_list[3])
proj_query = self.query_conv(Multi_X).view(m_batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
energy_content = torch.bmm(proj_query, proj_key)
content_position = (self.rel_h + self.rel_w).view(1, self.chanel_in//8, -1)
content_position = torch.matmul(proj_query,content_position)
energy = energy_content + content_position
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
return out, attention
class Decoder_Pos(nn.Module):
def __init__(self, n_dims, width=32, height=32):
super(Decoder_Pos, self).__init__()
print("================= Multi_Head_Decoder =================")
self.chanel_in = n_dims
self.rel_h = nn.Parameter(torch.randn([1, n_dims//8, height, 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, n_dims//8, 1, width]), requires_grad=True)
self.query_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=n_dims, out_channels=n_dims, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x,x_encoder):
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x_encoder).view(m_batchsize, -1, width * height)
energy_content = torch.bmm(proj_query, proj_key)
content_position = (self.rel_h + self.rel_w).view(1, self.chanel_in//8, -1)
content_position = torch.matmul(proj_query,content_position)
energy = energy_content+content_position
attention = self.softmax(energy)
proj_value = self.value_conv(x_encoder).view(m_batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
return out, attention
class MsTNL(nn.Module):
def __init__(self,train_dim,filters=[32,64,128,256]):
print("============= MsTNL =============")
super(MsTNL, self).__init__()
self.encoder = Encoder_Pos(train_dim,width=32,height=32,filters=filters)
self.decoder = Decoder_Pos(train_dim,width=32,height=32)
def forward(self, x, x_list):
x_encoder,att_en = self.encoder(x, x_list)
x_out,att_de = self.decoder(x,x_encoder)
return x_out
MsGCS.py
import torch
import torch.nn as nn
class MsGCS(nn.Module):
def __init__(self,F_g,F_l,F_int,size):
super(MsGCS, self).__init__()
print("=============== MsGCS ===============")
self.Conv = nn.Sequential(
nn.Conv2d(F_g+F_l,F_int,kernel_size=1),
nn.BatchNorm2d(F_int),
nn.ReLU(inplace=True),
nn.Conv2d(F_int,1,kernel_size=1),
nn.BatchNorm2d(1),
nn.ReLU(inplace=True)
)
self.rel_h = nn.Parameter(torch.randn([1, 1, size[0], 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, 1, 1, size[1]]), requires_grad=True)
self.bn = nn.BatchNorm2d(1)
self.active = nn.Sigmoid()
def forward(self,g,x):
x_Multi_Scale = self.Conv(torch.cat([g,x],dim=1))
content_position = self.rel_h+self.rel_w
x_att_multi_scale = self.active(self.bn(content_position*x_Multi_Scale))
return x_att_multi_scale*x