论文链接:https://arxiv.org/pdf/2311.17791.pdf
代码链接:https://github.com/yaoppeng/U-Net_v2/blob/master/unet_v2/UNet_v2.py
def forward(self, x):
# Excerpt of the model's forward pass: it prepares the encoder features and
# fuses them per stage. The snippet ends before any return statement, and
# seg_outs is not used within the lines shown here.
seg_outs = []
# The encoder yields four feature maps, one per stage.
f1, f2, f3, f4 = self.encoder(x)
# For every stage k: scale the feature by the output of ca_k, then scale the
# result by the output of sa_k (channel attention followed by spatial
# attention, per the notes in this file), and pass it through Translayer_k.
# NOTE(review): Translayer_k presumably projects each stage to a common
# channel count so the features can be fused later — confirm.
f1 = self.ca_1(f1) * f1
f1 = self.sa_1(f1) * f1
f1 = self.Translayer_1(f1)
f2 = self.ca_2(f2) * f2
f2 = self.sa_2(f2) * f2
f2 = self.Translayer_2(f2)
f3 = self.ca_3(f3) * f3
f3 = self.sa_3(f3) * f3
f3 = self.Translayer_3(f3)
f4 = self.ca_4(f4) * f4
f4 = self.sa_4(f4) * f4
f4 = self.Translayer_4(f4)
# Each sdi_k receives all four enhanced features plus the feature of its own
# stage as the anchor.
# NOTE(review): sdi_k are presumably instances of the SDI class in this file,
# in which case every output has the shape of its anchor — confirm.
f41 = self.sdi_4([f1, f2, f3, f4], f4)
f31 = self.sdi_3([f1, f2, f3, f4], f3)
f21 = self.sdi_2([f1, f2, f3, f4], f2)
f11 = self.sdi_1([f1, f2, f3, f4], f1)
class SDI(nn.Module):
    """Fuse several feature maps into one at the anchor's spatial resolution.

    Every input map is resized to the anchor's (height, width), passed
    through its own 3x3 convolution, and the convolution outputs are
    multiplied together element-wise.

    Args:
        channel: Input and output channel count of every convolution, so
            each fused feature map must have this many channels.
    """

    def __init__(self, channel):
        super().__init__()
        # One 3x3 / stride 1 / padding 1 convolution per input level; this
        # configuration keeps the spatial size unchanged.
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
             for _ in range(4)])

    def forward(self, xs, anchor):
        """Return the element-wise product of the resized, convolved inputs.

        Args:
            xs: Sequence of feature maps (at most ``len(self.convs)``); the
                i-th map is processed by the i-th convolution.
            anchor: Tensor whose last two dimensions give the target spatial
                size; the result has the same shape as ``anchor``.

        Returns:
            Tensor shaped like ``anchor``.

        Raises:
            IndexError: If ``xs`` holds more maps than there are convolutions.
        """
        # Use both spatial dimensions so non-square feature maps are resized
        # to the anchor's real (H, W) instead of a square (W, W).
        target_h, target_w = anchor.shape[-2:]
        # Start from ones so the running product equals the first conv output.
        ans = torch.ones_like(anchor)
        for i, x in enumerate(xs):
            height, width = x.shape[-2:]
            if height > target_h or width > target_w:
                # Larger than the anchor: downsample by adaptive average pooling.
                x = F.adaptive_avg_pool2d(x, (target_h, target_w))
            elif height < target_h or width < target_w:
                # Smaller than the anchor: upsample bilinearly.
                x = F.interpolate(x, size=(target_h, target_w),
                                  mode='bilinear', align_corners=True)
            ans = ans * self.convs[i](x)
        return ans
过去的UNet在上采样的过程中,每次只通过拼接(concat)的方式复用编码器中一个stage的特征;
这里则是解码器的每个stage都会通过哈达玛积(逐元素相乘)的方式复用编码器中所有stage的特征。
在复用前,会对编码器每个stage的特征依次施加通道注意力和空间注意力(二者串联)进行增强。