Rethinking Boundary Detection in Deep Learning Models for Medical Image Segmentation, IPMI 2023
Commentary: IPMI 2023, new work from Hao Chen's group at HKUST | CTO: Rethinking the role of boundary detection in medical image segmentation (qq.com)
Paper: https://arxiv.org/abs/2305.00678
Code: https://github.com/xiaofang007/CTO
Introduction
This paper proposes a novel network architecture named CTO (Convolution, Transformer, and Operator), which combines a convolutional neural network, a vision Transformer, and an explicit boundary detection operator to achieve high-accuracy image segmentation while maintaining a good balance between accuracy and efficiency.
CTO follows the standard encoder-decoder segmentation paradigm. The encoder network adopts a popular CNN backbone to capture local semantic information and an auxiliary lightweight ViT network to integrate long-range dependencies. To strengthen boundary learning, the paper further proposes a boundary-guided decoder network that uses the boundary mask produced by a dedicated boundary detection operator as explicit supervision to guide the decoding process.
Convolution, Transformer, and Operator (CTO)
CTO follows the encoder-decoder paradigm and uses skip connections to aggregate low-level features from the encoder into the decoder. The encoder network consists of a mainstream CNN and an auxiliary ViT, while the decoder network employs a boundary detection operator to guide its learning process. The two key components are:
- A dual-stream encoder, which combines a convolutional neural network and a lightweight vision Transformer to capture local feature dependencies within the image and long-range dependencies between image patches, respectively.
- An operator-guided decoder, which uses a boundary detection operator (e.g., Sobel) to guide the learning process via the generated boundary masks; the whole model is trained end-to-end.
Dual-Stream Encoder
CTO first builds a convolutional stream, choosing Res2Net as the backbone network to capture local feature dependencies.
CTO also uses an auxiliary stream based on a lightweight Vision Transformer to capture long-range dependencies between different image patches. Specifically, it consists of multiple parallel lightweight Transformer blocks that take feature patches at different scales as input. All Transformer blocks share a similar structure, comprising a patch embedding layer and Transformer encoder layers.
The patch embedding layer of the LightViT converts the input feature patches into embedding vectors, turning the spatial dimensions into a sequence dimension. The Transformer encoder layers then model the feature patches with self-attention to capture long-range dependencies between them. By introducing self-attention into the Transformer blocks, the LightViT can effectively model the interactions between feature patches and extract global contextual information from the image.
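Below is a minimal sketch of this dual-stream idea, assuming a backbone that returns four feature maps (as the repository's res2net50_v1b_26w_4s does, with 256 channels at the first stage) and using PyTorch's nn.MultiheadAttention as a stand-in for the paper's lightweight Transformer block; each spatial location is treated as a token here for simplicity, rather than the paper's multi-scale patch attention, and all class names are illustrative.

```python
import torch
import torch.nn as nn

class TinyPatchTransformer(nn.Module):
    """Hypothetical stand-in for one lightweight Transformer block:
    flatten a feature map into a token sequence, apply self-attention, reshape back."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

    def forward(self, x):                      # x: (B, C, H, W)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # "patch embedding": spatial dims -> sequence (B, H*W, C)
        attn_out, _ = self.attn(tokens, tokens, tokens)
        tokens = self.norm(tokens + attn_out)  # residual + norm
        tokens = tokens + self.mlp(tokens)     # feed-forward
        return tokens.transpose(1, 2).reshape(b, c, h, w)

class DualStreamEncoderSketch(nn.Module):
    """CNN stream for local, multi-scale features + lightweight ViT stream for global context."""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone               # expected to return 4 feature maps, e.g. res2net50_v1b_26w_4s
        self.vit = TinyPatchTransformer(dim=256)

    def forward(self, x):
        x1, x2, x3, x4 = self.backbone(x)      # local feature dependencies at 4 scales
        global_feat = self.vit(x1)             # long-range dependencies on the early feature map
        return (x1, x2, x3, x4), global_feat
```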
Boundary-Guided Decoder
The boundary-guided decoder uses a gradient-operator module to extract boundary information of the foreground objects. A boundary refinement module then integrates the boundary-enhanced features with multi-level encoder features, aiming to characterize both intra-class and inter-class consistency in the feature space and to enrich the feature representation. This allows the decoder to make better use of boundary information when generating segmentation results, producing more accurate segmentations.
Boundary Enhanced Module (BEM)
The boundary enhanced module takes high-level and low-level features as input, extracts boundary information, and filters out boundary-irrelevant information. The Sobel operator is applied along the horizontal direction (Gx) and the vertical direction (Gy) to obtain gradient maps. Concretely, two 3×3 convolutions with fixed parameters and stride 1 are used, defined as:

Gx = [[1, 0, -1], [2, 0, -2], [1, 0, -1]],  Gy = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
These two convolutions are applied to the input feature map to obtain the gradient maps Mx and My. The gradient maps are then normalized with the sigmoid function and fused with the input feature map, yielding the edge-enhanced feature map Fe:

F_e = σ(M_xy) ⊙ F

where ⊙ denotes element-wise multiplication, σ denotes the sigmoid function, F is the input feature map, and M_xy is the concatenation of Mx and My along the channel dimension. The edge-enhanced feature map can then be fused directly with a few stacked convolutional layers. Finally, the output feature map is supervised by the ground-truth boundary map, which suppresses edge responses inside objects and produces boundary-enhanced features.
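Below is a minimal sketch of the BEM idea under the definitions above. Module and variable names are illustrative; for brevity it takes a single feature map, and it combines the two gradient maps via their magnitude (as the repository's run_sobel helper does) rather than the channel-wise concatenation in the formula.

```python
import torch
import torch.nn as nn

class BEMSketch(nn.Module):
    """Boundary Enhanced Module sketch: fixed Sobel convs -> gradient maps Mx, My ->
    sigmoid-normalized boundary attention -> fusion with the input feature map."""
    def __init__(self, channels):
        super().__init__()
        gx = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
        gy = gx.t()
        # one fixed 3x3 kernel per channel (depthwise), stride 1, never trained
        self.sobel_x = nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False)
        self.sobel_y = nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False)
        self.sobel_x.weight = nn.Parameter(gx.repeat(channels, 1, 1, 1), requires_grad=False)
        self.sobel_y.weight = nn.Parameter(gy.repeat(channels, 1, 1, 1), requires_grad=False)
        self.fuse = nn.Sequential(                    # stacked convs that fuse the enhanced features
            nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels), nn.ReLU(inplace=True))
        self.to_boundary = nn.Conv2d(channels, 1, 1)  # logits supervised by the GT boundary map

    def forward(self, feat):
        mx = self.sobel_x(feat)                                       # gradient map Mx
        my = self.sobel_y(feat)                                       # gradient map My
        m_xy = torch.sigmoid(torch.sqrt(mx ** 2 + my ** 2 + 1e-6))    # normalized gradient magnitude
        f_e = self.fuse(m_xy * feat)                                  # boundary-enhanced features F_e
        return f_e, self.to_boundary(f_e)
```

Because the Sobel weights are frozen, the operator injects a boundary prior without adding trainable parameters beyond the fusion convolutions.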
Boundary Inject Module (BIM)
The boundary-enhanced features produced by the BEM serve as prior knowledge to improve the image representation capability of the features generated by the encoder. The BIM introduces a dual-path boundary fusion scheme that strengthens the representation of both foreground and background. Specifically, the BIM receives two inputs: the channel-wise concatenation of the boundary-enhanced features with the corresponding encoder features, and the features from the previous decoder layer. These two inputs are fed into the BIM, which contains two separate paths that enhance the foreground and background feature representations, respectively.
- For the foreground path, the two inputs are concatenated along the channel dimension and passed through a series of Conv-BN-ReLU (convolution, batch normalization, ReLU activation) layers to obtain the foreground features.
- For the background path, a background attention component is designed to selectively attend to background information.
The foreground path produces the foreground features Ffg, and the background path produces the background features Fbg.
The foreground attention map is obtained by applying a sigmoid to the feature map from the previous decoder layer, and the background attention map is 1 minus the foreground attention map. Finally, the foreground features Ffg, the background features Fbg, and the previous decoder-layer features are concatenated to form the output of this layer.
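A minimal sketch of this dual-path fusion follows, assuming both inputs have already been brought to the same spatial resolution; the class and argument names are illustrative rather than the repository's API.

```python
import torch
import torch.nn as nn

def conv_bn_relu(in_ch, out_ch):
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),
                         nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

class BIMSketch(nn.Module):
    """Boundary Inject Module sketch: a foreground path and a background-attention path,
    followed by fusion with the previous decoder feature."""
    def __init__(self, enc_ch, dec_ch, out_ch):
        super().__init__()
        self.fg_path = conv_bn_relu(enc_ch + dec_ch, out_ch)  # foreground: plain concat + Conv-BN-ReLU
        self.bg_path = conv_bn_relu(enc_ch + dec_ch, out_ch)  # background: same convs, gated by 1 - attention
        self.to_attn = nn.Conv2d(dec_ch, 1, 1)                # previous decoder feature -> attention logits
        self.out = conv_bn_relu(out_ch * 2 + dec_ch, out_ch)

    def forward(self, boundary_enhanced_enc_feat, prev_dec_feat):
        x = torch.cat([boundary_enhanced_enc_feat, prev_dec_feat], dim=1)
        fg_attn = torch.sigmoid(self.to_attn(prev_dec_feat))  # foreground attention map
        bg_attn = 1.0 - fg_attn                                # background attention map
        f_fg = self.fg_path(x)                                 # foreground features Ffg
        f_bg = self.bg_path(x) * bg_attn                       # background features Fbg
        return self.out(torch.cat([f_fg, f_bg, prev_dec_feat], dim=1))
```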
Loss Function
CTO is a multi-task model with an interior segmentation task and a boundary detection task, and an overall loss function is defined to optimize both jointly. The overall loss consists of the main interior segmentation loss L_seg and the boundary loss L_bnd. For the boundary detection loss, only the prediction from the BEM, which takes the encoder's high-level and low-level feature maps as input, is considered.
Interior Segmentation Loss
L_seg is a weighted sum of the cross-entropy loss L_CE and the mean intersection-over-union (mIoU) loss L_mIoU:

L_seg = λ1 · L_CE + λ2 · L_mIoU,

where λ1 and λ2 are weighting coefficients balancing the two terms.
Boundary Loss
The boundary loss L_bnd accounts for the severe class imbalance between foreground (boundary) and background pixels in boundary detection, and therefore adopts the Dice loss between the predicted boundary map and the ground-truth boundary map.
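A minimal sketch of the multi-task objective as described above, assuming L_seg combines cross-entropy with a soft IoU term and L_bnd is a Dice loss on the boundary prediction; the weighting coefficients and the exact soft-IoU/Dice formulations here are generic placeholders, not necessarily the paper's choices.

```python
import torch
import torch.nn.functional as F

def soft_iou_loss(logits, target, eps=1e-6):
    """Generic soft IoU loss; target: (B, H, W) integer class indices."""
    probs = torch.softmax(logits, dim=1)
    onehot = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    inter = (probs * onehot).sum(dim=(2, 3))
    union = (probs + onehot - probs * onehot).sum(dim=(2, 3))
    return (1.0 - (inter + eps) / (union + eps)).mean()

def dice_loss(boundary_logits, boundary_gt, eps=1e-6):
    """Dice loss for the boundary map; boundary_gt: (B, 1, H, W) in {0, 1}."""
    p = torch.sigmoid(boundary_logits)
    inter = (p * boundary_gt).sum(dim=(1, 2, 3))
    denom = p.sum(dim=(1, 2, 3)) + boundary_gt.sum(dim=(1, 2, 3))
    return (1.0 - (2.0 * inter + eps) / (denom + eps)).mean()

def cto_total_loss(seg_logits, seg_gt, boundary_logits, boundary_gt, w_iou=1.0, w_bnd=1.0):
    """Total loss = L_seg + L_bnd, with L_seg = L_CE + w_iou * L_mIoU (weights are placeholders)."""
    l_seg = F.cross_entropy(seg_logits, seg_gt) + w_iou * soft_iou_loss(seg_logits, seg_gt)
    l_bnd = dice_loss(boundary_logits, boundary_gt)
    return l_seg + w_bnd * l_bnd
```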
Experiments
Key Code
CTO_net.py
# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py
# Imports required by the code below; math/log, numpy and torch are standard, while
# res2net50_v1b_26w_4s, BaseNetwork and FeedForward2D are defined in other files of the CTO repository.
import math
from math import log

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBNR(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
super(ConvBNR, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class Conv1x1(nn.Module):
def __init__(self, inplanes, planes):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(inplanes, planes, 1)
self.bn = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
# NOTE: this three-input EAM is shadowed by the two-input EAM re-defined later in this file;
# the CTO model below instantiates the later definition.
class EAM(nn.Module):
def __init__(self):
super(EAM, self).__init__()
self.reduce1 = Conv1x1(256, 64)
self.reduce4 = Conv1x1(512, 256)
self.block = nn.Sequential(
ConvBNR(320 + 64, 256, 3),
ConvBNR(256, 256, 3),
nn.Conv2d(256, 1, 1))
def forward(self, x1, x11, p2):
size = x1.size()[2:]
x1 = self.reduce1(x1)
x11 = self.reduce1(x11)
p2 = self.reduce4(p2)
p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
out = torch.cat((x1, x11), dim=1)
out = torch.cat((out, p2), dim=1)
out = self.block(out)
return out
class EFM(nn.Module):
def __init__(self, channel):
super(EFM, self).__init__()
t = int(abs((log(channel, 2) + 1) / 2))
k = t if t % 2 else t + 1
self.conv2d = ConvBNR(channel, channel, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, c, att):
if c.size() != att.size():
att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
x = c * att + c
x = self.conv2d(x)
wei = self.avg_pool(x)
wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
wei = self.sigmoid(wei)
x = x * wei
return x
class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_planes, out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
        return x  # note: self.relu is defined above but not applied here (Conv + BN only)
class DM(nn.Module):
def __init__(self):
super(DM, self).__init__()
self.predict3 = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
nn.Conv2d(64, 1, kernel_size=1)
)
self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
def forward(self, xr, dualattention):
crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
x = -1*(torch.sigmoid(crop_3)) + 1
x = x.expand(-1, 64, -1, -1).mul(xr)
x = F.relu(self.ra2_conv2(x))
x = F.relu(self.ra2_conv3(x))
ra3_feat = self.ra2_conv4(x)
x = ra3_feat + crop_3 + re3_feat
return x
class _DAHead(nn.Module):
def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
super(_DAHead, self).__init__()
self.aux = aux
inter_channels = in_channels // 4
self.conv_p1 = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.conv_c1 = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.pam = _PositionAttentionModule(inter_channels, **kwargs)
self.cam = _ChannelAttentionModule(**kwargs)
self.conv_p2 = nn.Sequential(
nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.conv_c2 = nn.Sequential(
nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.out = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
if aux:
self.conv_p3 = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
self.conv_c3 = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
def forward(self, x):
feat_p = self.conv_p1(x)
feat_p = self.pam(feat_p)
feat_p = self.conv_p2(feat_p)
feat_c = self.conv_c1(x)
feat_c = self.cam(feat_c)
feat_c = self.conv_c2(feat_c)
feat_fusion = feat_p + feat_c
outputs = []
fusion_out = self.out(feat_fusion)
outputs.append(fusion_out)
if self.aux:
p_out = self.conv_p3(feat_p)
c_out = self.conv_c3(feat_c)
outputs.append(p_out)
outputs.append(c_out)
return tuple(outputs)
def run_sobel(conv_x, conv_y, input):
g_x = conv_x(input)
g_y = conv_y(input)
g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
return torch.sigmoid(g) * input
def get_sobel(in_chan, out_chan):
'''
filter_x = np.array([
[3, 0, -3],
[10, 0, -10],
[3, 0, -3],
]).astype(np.float32)
filter_y = np.array([
[3, 10, 3],
[0, 0, 0],
[-3, -10, -3],
]).astype(np.float32)
'''
filter_x = np.array([
[1, 0, -1],
[2, 0, -2],
[1, 0, -1],
]).astype(np.float32)
filter_y = np.array([
[1, 2, 1],
[0, 0, 0],
[-1, -2, -1],
]).astype(np.float32)
filter_x = filter_x.reshape((1, 1, 3, 3))
filter_x = np.repeat(filter_x, in_chan, axis=1)
filter_x = np.repeat(filter_x, out_chan, axis=0)
filter_y = filter_y.reshape((1, 1, 3, 3))
filter_y = np.repeat(filter_y, in_chan, axis=1)
filter_y = np.repeat(filter_y, out_chan, axis=0)
filter_x = torch.from_numpy(filter_x)
filter_y = torch.from_numpy(filter_y)
filter_x = nn.Parameter(filter_x, requires_grad=False)
filter_y = nn.Parameter(filter_y, requires_grad=False)
conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
conv_x.weight = filter_x
conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
conv_y.weight = filter_y
sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
return sobel_x, sobel_y
class GlobalFilter(nn.Module):
def __init__(self, dim=32, h=64, w=33, fp32fft=True):
super().__init__()
self.complex_weight = nn.Parameter(
torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
)
self.w = w
self.h = h
self.fp32fft = fp32fft
    def forward(self, x):
        # x: (B, C, H, W) -> global filtering in the 2D Fourier domain
        _, _, h, w = x.size()
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight                           # learnable complex-valued filter
        x = torch.fft.irfft2(x, s=(h, w), dim=(1, 2), norm="ortho")
        if self.fp32fft:
            x = x.to(dtype)
        x = x.permute(0, 3, 1, 2).contiguous()   # back to (B, C, H, W)
        return x
class ERB(nn.Module):
def __init__(self, in_channels, out_channels):
super(ERB, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.bn = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x, relu=True):
x = self.conv1(x)
res = self.conv2(x)
res = self.bn(res)
res = self.relu(res)
res = self.conv3(res)
if relu:
return self.relu(x + res)
else:
return x+res
class _PositionAttentionModule(nn.Module):
""" Position attention module"""
def __init__(self, in_channels, **kwargs):
super(_PositionAttentionModule, self).__init__()
self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
self.alpha = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, _, height, width = x.size()
feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
feat_c = self.conv_c(x).view(batch_size, -1, height * width)
attention_s = self.softmax(torch.bmm(feat_b, feat_c))
feat_d = self.conv_d(x).view(batch_size, -1, height * width)
feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
out = self.alpha * feat_e + x
return out
class _ChannelAttentionModule(nn.Module):
"""Channel attention module"""
def __init__(self, **kwargs):
super(_ChannelAttentionModule, self).__init__()
self.beta = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, _, height, width = x.size()
feat_a = x.view(batch_size, -1, height * width)
feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
attention = torch.bmm(feat_a, feat_a_transpose)
attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
attention = self.softmax(attention_new)
feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
out = self.beta * feat_e + x
return out
# This re-definition of EAM (two inputs, 2048-channel high-level features) overrides the earlier one
# and is the version actually used by CTO below.
class EAM(nn.Module):
def __init__(self):
super(EAM, self).__init__()
self.reduce1 = Conv1x1(256, 64)
self.reduce4 = Conv1x1(2048, 256)
self.block = nn.Sequential(
ConvBNR(256 + 64, 256, 3),
ConvBNR(256, 256, 3),
nn.Conv2d(256, 1, 1))
def forward(self, x4, x1):
size = x1.size()[2:]
x1 = self.reduce1(x1)
x4 = self.reduce4(x4)
x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
out = torch.cat((x4, x1), dim=1)
out = self.block(out)
return out
def attention(query, key, value):
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
query.size(-1)
)
p_attn = F.softmax(scores, dim=-1)
p_val = torch.matmul(p_attn, value)
return p_val, p_attn
class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""
def __init__(self, patchsize, d_model):
super().__init__()
self.patchsize = patchsize
self.query_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.value_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.key_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.output_linear = nn.Sequential(
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
nn.BatchNorm2d(d_model),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x):
        b, c, h, w = x.size()  # e.g. (8, 256, 64, 64)
d_k = c // len(self.patchsize)
output = []
_query = self.query_embedding(x)#8,32,80,80
_key = self.key_embedding(x)#8,32,80,80
_value = self.value_embedding(x)#8,32,80,80
attentions = []
for (width, height), query, key, value in zip(
self.patchsize,
torch.chunk(_query, len(self.patchsize), dim=1),
torch.chunk(_key, len(self.patchsize), dim=1),
torch.chunk(_value, len(self.patchsize), dim=1),
):
            # _query: (B, 256, H, W); after chunking over the 4 patch sizes, each query is (B, 64, H, W)
out_w, out_h = w // width, h // height#
## 1) embedding and reshape
query = query.view(b, d_k, out_h, height, out_w, width)
query = (
query.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
key = key.view(b, d_k, out_h, height, out_w, width)
key = (
key.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
value = value.view(b, d_k, out_h, height, out_w, width)
value = (
value.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
y, _ = attention(query, key, value)
# 3) "Concat" using a view and apply a final linear.
y = y.view(b, out_h, out_w, d_k, height, width)
y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
attentions.append(y)
output.append(y)
output = torch.cat(output, 1)
self_attention = self.output_linear(output)
return self_attention
class TransformerBlock(nn.Module):
"""
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
"""
def __init__(self, patchsize, in_channel=256):
super().__init__()
self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
self.feed_forward = FeedForward2D(
in_channel=in_channel, out_channel=in_channel
)
def forward(self, rgb):
self_attention = self.attention(rgb)
output = rgb + self_attention
output = output + self.feed_forward(output)
return output
class PatchTrans(BaseNetwork):
def __init__(self, in_channel, in_size):#32,80
super(PatchTrans, self).__init__()
self.in_size = in_size#80
patchsize = [
(32,32),#80,80
(16,16),#40,40
(8,8),#20,20
(4,4),#10,10
]
self.t = TransformerBlock(patchsize, in_channel=in_channel)
def forward(self, enc_feat):
output = self.t(enc_feat)
return output
class multi(nn.Module):
    # Simplified variant of EFM: only multiplies the feature map by the (resized) attention map;
    # the channel-attention layers are kept for compatibility but unused in forward.
    def __init__(self, channel):
        super(multi, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        return c * att
class CTO(nn.Module):
def __init__(self,seg_classes):
super(CTO, self).__init__()
self.resnet = res2net50_v1b_26w_4s(pretrained=True)
# if self.training:
# self.initialize_weights()
self.fft = GlobalFilter(dim = 3 , h=256, w=129, fp32fft= True)
self.multi_trans = PatchTrans(in_channel=256,in_size=64)
self.num_class = seg_classes
self.eam = EAM()
self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
self.erb_db_1 = ERB(256, self.num_class)
self.erb_db_2 = ERB(512, self.num_class)
self.erb_db_3 = ERB(1024, self.num_class)
self.erb_db_4 = ERB(2048, self.num_class)
self.head = _DAHead(2048+256, 2048, aux=False)
self.reduce1 = Conv1x1(256, 64)
self.reduce2 = Conv1x1(512, 64)
self.reduce3 = Conv1x1(1024, 64)
self.reduce4 = Conv1x1(2048, 64)
self.reduce5 = Conv1x1(2048, 1)
self.dm1 = DM()
self.dm2 = DM()
self.dm3 = DM()
self.dm4 = DM()
self.predictor1 = nn.Conv2d(64, self.num_class, 1)
self.predictor2 = nn.Conv2d(64, self.num_class, 1)
self.predictor3 = nn.Conv2d(64, self.num_class, 1)
self.predictor4 = nn.Conv2d(64, self.num_class, 1)
# def initialize_weights(self):
# model_state = torch.load('./models/resnet50-19c8e357.pth')
# self.resnet.load_state_dict(model_state, strict=False)
def forward(self, x):
        fft_fea = self.fft(x)  # frequency-domain features (B, 3, 256, 256); computed but not used further in this forward
x1, x2, x3 ,x4= self.resnet(x)#[16, 256, 64, 64] [16, 512, 32, 32] [16, 1024, 16, 16] [16, 2048, 8, 8]
trans = self.multi_trans(x1)#16,256,64,64
s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
edge = self.eam(s4, s1)
edge_att = torch.sigmoid(edge)#[16, 1, 64, 64]
trans = F.interpolate(trans,x4.size()[2:], mode='bilinear', align_corners=False)#256,8,8
dual_attention = self.head(torch.cat([trans, x4], dim=1))[0] #2048,8,8
x1a = x1*edge_att
edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
x2a = x2*edge_att2
edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
x3a = x3*edge_att3
#x1a = self.efm1(x1, edge_att)
#x2a = self.efm2(x2, edge_att)
# x3a = self.efm3(x3, edge_att)
# x4a = self.efm4(x4, edge_att)
x1r = self.reduce1(x1a)
x2r = self.reduce2(x2a)#128,32,32
x3r = self.reduce3(x3a)#256,16,16
dual_attention = self.reduce4(dual_attention)
c3 = self.dm3(x3r, dual_attention) #256 16 16
c2 = self.dm2(x2r, c3) #128 32 32
c1 = self.dm1(x1r, c2) #64 64 64
o3 = self.predictor3(c3)
o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
o2 = self.predictor2(c2)
o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False)
o1 = self.predictor1(c1)
o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)
return o3, o2, o1, oe