class Fusion(nn.Module):
    """Fuse a list of same-shaped feature maps into a single tensor.

    Supported strategies (``fusion``):
      - ``'weight'``:   per-input 1x1 conv, then elementwise sum.
      - ``'adaptive'``: per-input 1x1 conv, then a softmax-weighted sum whose
                        per-pixel weights are predicted from the concatenated inputs.
      - ``'concat'``:   channel-wise concatenation.
      - ``'bifpn'``:    BiFPN-style fast normalized fusion with one learnable
                        non-negative scalar weight per input.

    Args:
        inc_list: channel count of each incoming feature map.
        fusion: one of ``'weight'``, ``'adaptive'``, ``'concat'``, ``'bifpn'``.
    """

    def __init__(self, inc_list, fusion='bifpn') -> None:
        super().__init__()
        assert fusion in ['weight', 'adaptive', 'concat', 'bifpn']
        self.fusion = fusion
        if self.fusion == 'bifpn':
            # One learnable scalar per input; ReLU in forward() keeps them non-negative.
            self.fusion_weight = nn.Parameter(torch.ones(len(inc_list), dtype=torch.float32), requires_grad=True)
            self.relu = nn.ReLU()
            # Stabilizes the normalization denominator (fast normalized fusion, EfficientDet).
            self.epsilon = 1e-4
        else:
            self.fusion_conv = nn.ModuleList([Conv(inc, inc, 1) for inc in inc_list])
            if self.fusion == 'adaptive':
                self.fusion_adaptive = Conv(sum(inc_list), len(inc_list), 1)

    def forward(self, x):
        """Fuse the list of feature maps ``x`` according to ``self.fusion``.

        Args:
            x: list of tensors; for every mode except 'concat' they are assumed
               to share the same shape -- TODO confirm against callers.

        Returns:
            The fused tensor (concatenated along dim 1 for 'concat').
        """
        if self.fusion in ['weight', 'adaptive']:
            # Build a new list instead of mutating the caller's list in place.
            x = [conv(xi) for conv, xi in zip(self.fusion_conv, x)]
        if self.fusion == 'weight':
            return torch.sum(torch.stack(x, dim=0), dim=0)
        elif self.fusion == 'adaptive':
            # Per-pixel softmax over len(x) predicted weight channels.
            fusion = torch.softmax(self.fusion_adaptive(torch.cat(x, dim=1)), dim=1)
            x_weight = torch.split(fusion, [1] * len(x), dim=1)
            return sum(w * xi for w, xi in zip(x_weight, x))
        elif self.fusion == 'concat':
            return torch.cat(x, dim=1)
        elif self.fusion == 'bifpn':
            # BUGFIX: self.epsilon was defined in __init__ but never used, so the
            # normalization divided by zero (-> NaN) whenever all learnable weights
            # ReLU'd to 0. The redundant .clone() before ReLU was also dropped
            # (ReLU is not in-place here, so it already returns a fresh tensor).
            fusion_weight = self.relu(self.fusion_weight)
            fusion_weight = fusion_weight / (torch.sum(fusion_weight, dim=0) + self.epsilon)
            return sum(w * xi for w, xi in zip(fusion_weight, x))
06-28
04-18
08-31
08-28
03-14
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交