能够动态适应输入的多尺度融合结构（DCM），代码如下：
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/3937c087eaea8a949e482c57a7ade760.png)
class DCMLayer(nn.Module):
    """Dynamic convolution layer whose depthwise kernel is generated from the input.

    For every sample, a ``k x k`` depthwise kernel is produced by adaptively
    average-pooling the input to ``k x k`` and passing it through the same 1x1
    reduction conv used for the features. That per-sample kernel is then
    convolved with the sample's channel-reduced features, and a 1x1 conv
    restores the channel count.

    Args:
        k: spatial size of the generated kernel (and of the pooled context).
            With the ``(k - 1) // 2`` padding used here, odd ``k`` keeps the
            spatial size unchanged and even ``k`` shrinks each spatial
            dimension by one.
        channel: number of input and output channels; the internal width is
            ``channel // 4``.
    """

    def __init__(self, k, channel):
        super(DCMLayer, self).__init__()
        self.k = k
        self.channel = channel
        # 1x1 reduction, shared by the feature branch and the kernel generator.
        self.conv = nn.Conv2d(channel, channel // 4, 1, padding=0, bias=True)
        self.fuse = nn.Conv2d(channel // 4, channel, 1, padding=0, bias=True)
        # forward() uses only this layer's bias and conv geometry (padding,
        # stride, dilation); its weight is superseded by the generated kernel.
        # It is kept so existing state_dict keys and checkpoints still load.
        self.dw_conv = nn.Conv2d(
            channel // 4,
            channel // 4,
            self.k,
            padding=(self.k - 1) // 2,
            groups=channel // 4,
        )
        self.pool = nn.AdaptiveAvgPool2d(k)

    def forward(self, x):
        """Apply the per-sample generated depthwise kernel, then fuse channels.

        Args:
            x: input of shape ``(N, channel, H, W)``.

        Returns:
            Tensor of shape ``(N, channel, H', W')``, where ``H' == H`` and
            ``W' == W`` for odd ``k``.
        """
        n = x.shape[0]
        f = self.conv(x)  # (N, C', H, W) with C' = channel // 4
        g = self.conv(self.pool(x))  # (N, C', k, k): one kernel per sample/channel
        c = f.shape[1]

        # Fold the batch into the channel axis, so a single grouped convolution
        # applies each sample's kernels only to that sample's channels.
        # Never assign the kernel to dw_conv.weight via nn.Parameter: that
        # detaches it from the graph (the generator would get no gradient) and
        # mutates module state during forward.
        weight = g.reshape(n * c, 1, g.shape[2], g.shape[3])
        bias = self.dw_conv.bias
        if bias is not None:
            # Same per-channel bias for every sample, in sample-major order.
            bias = bias.repeat(n)

        y = nn.functional.conv2d(
            f.reshape(1, n * c, f.shape[2], f.shape[3]),
            weight,
            bias,
            stride=self.dw_conv.stride,
            padding=self.dw_conv.padding,
            dilation=self.dw_conv.dilation,
            groups=n * c,
        )
        y = y.reshape(n, c, y.shape[2], y.shape[3])
        return self.fuse(y)
class DCM(nn.Module):
    """Multi-scale fusion of dynamic convolution branches.

    The same input goes through three ``DCMLayer`` branches with generated
    kernel sizes 1, 3 and 5. Their outputs are concatenated with the input
    itself along the channel axis, and a 1x1 convolution maps the resulting
    ``4 * channel`` channels back to ``channel``.

    Args:
        channel: number of input and output channels.
    """

    def __init__(self, channel):
        super().__init__()
        # Registration order (DCM1, DCM3, DCM5, conv) determines parameter
        # order and the RNG sequence consumed during initialisation.
        self.DCM1 = DCMLayer(1, channel)
        self.DCM3 = DCMLayer(3, channel)
        self.DCM5 = DCMLayer(5, channel)
        self.conv = nn.Conv2d(4 * channel, channel, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        """Concatenate the identity path with the three scales and project back."""
        features = [x]
        for branch in (self.DCM1, self.DCM3, self.DCM5):
            features.append(branch(x))
        stacked = torch.cat(features, dim=1)
        return self.conv(stacked)