paper:HiFuse: Hierarchical multi-scale feature fusion network for medical image classification
1、Hierarchical Feature Fusion Module
在医学图像分类任务中,需要同时提取局部空间特征和全局语义信息,而单一网络结构难以兼顾两者。现有融合方法(如 ViTAE、StoHisNet 等)虽然取得了不错的效果,但在医学图像领域表现不佳,可能是因为医学图像数据量小、特征分散,难以有效学习。因此,这篇论文提出了一种层次特征融合模块(Hierarchical Feature Fusion Module,HFF)。HFF 模块通过以下机制实现特征融合:
- 空间注意力机制:增强局部特征中重要区域的表示,抑制无关信息。
- 通道注意力机制:增强全局特征中重要通道的表示,提高特征表达能力。
- 倒残差多层感知器 (IRMLP):学习融合后的特征,并通过深度可分离卷积降低计算成本。
- 捷径连接(shortcut):提高模型稳定性,缓解梯度消失。
实现过程:
- 将局部特征和全局特征分别输入 HFF 模块。
- 对局部特征应用空间注意力机制,对全局特征应用通道注意力机制。
- 将注意力权重(经 Sigmoid 归一化)与对应的输入特征逐元素相乘,分别得到加权后的局部特征和全局特征。
- 若存在上一层的融合特征,先用 1×1 卷积提升其通道数,再经平均池化下采样;随后与经 1×1 卷积投影的局部特征、全局特征拼接,并依次通过 LayerNorm、1×1 卷积和 GELU,得到融合后的特征表示。
- 将融合后的特征、加权后的局部特征和全局特征进行拼接。
- 将拼接后的特征输入 IRMLP 模块进行非线性变换,得到最终的融合特征。
优势:
- HFF 模块能够有效地融合不同层次、不同来源的特征,提高模型的表达能力。
- HFF 模块的设计简洁,易于理解和实现。
- HFF 模块能够适应不同尺度的特征,具有较强的泛化能力。
Hierarchical Feature Fusion Module 结构图:
2、代码实现
import torch
import torch.nn as nn
def drop_path_f(x, drop_prob=0., training=False):
    """Apply stochastic depth (drop path) to ``x``.

    For each sample along dim 0 a single Bernoulli draw decides whether the
    whole sample is kept or zeroed. Kept samples are divided by the keep
    probability.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample.
        training: When False the input is returned untouched.

    Returns:
        ``x`` itself when ``drop_prob == 0`` or ``training`` is False,
        otherwise a new tensor with the same shape as ``x``.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast over every remaining dimension, so this
    # works with tensors of any rank and not just 2D ConvNet activations.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 -> keep, 0 -> drop
    # Only rescale when something can be kept. With keep_prob == 0 the
    # division would produce inf, and inf * 0 would turn the output into NaN.
    if keep_prob > 0.:
        x = x.div(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Module wrapper around ``drop_path_f`` (stochastic depth per sample).

    Args:
        drop_prob: Probability of dropping a sample. ``None`` (the default)
            and ``0`` both disable dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # The constructor default is None; passing that on would hit
        # ``1 - None`` in training mode, so treat it as "never drop" here.
        if self.drop_prob is None or self.drop_prob == 0. or not self.training:
            return x
        return drop_path_f(x, self.drop_prob, self.training)
class LayerNorm(nn.Module):
    """Layer normalization over the channel dimension for two layouts.

    ``channels_last`` normalizes the trailing dimension through
    ``torch.nn.functional.layer_norm``. ``channels_first`` normalizes dim 1
    by hand and applies the affine parameters with two trailing broadcast
    dimensions.

    Args:
        normalized_shape: Number of channels being normalized.
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        ValueError: If ``data_format`` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise ValueError(f"not support data format '{self.data_format}'")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            # Reached through ``nn`` because the file imports only ``torch``
            # and ``torch.nn as nn``; a bare ``F`` alias is not defined.
            return nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # [batch_size, channels, height, width]
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)  # biased variance
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
class HFF_block(nn.Module):
    """Hierarchical Feature Fusion block.

    Fuses a local feature map ``l``, a global feature map ``g`` and,
    optionally, the fused output ``f`` of a previous stage:

    * ``l`` is re-weighted by a spatial attention map built from its
      channel-wise max and mean.
    * ``g`` is re-weighted by a channel attention vector built from its
      global max- and average-pooled descriptors.
    * 1x1 projections of ``l`` and ``g`` (plus ``f`` when given) are
      concatenated, normalized and reduced to ``ch_int`` channels.
    * The three results are concatenated, normalized and passed through
      an ``IRMLP``; the projected ``f`` is added back as a shortcut.

    Args:
        ch_1: Channels of the local feature ``l``.
        ch_2: Channels of the global feature ``g``.
        ch_int: Channels of the intermediate fused feature. ``f`` must
            carry ``ch_int // 2`` channels.
        ch_out: Channels produced by the block. When ``f`` is supplied the
            shortcut has ``ch_int`` channels and is added to the output, so
            the two must be compatible for addition.
        r_2: Channel reduction ratio of the channel-attention bottleneck.
        drop_rate: Drop-path probability applied to the IRMLP output.
    """

    def __init__(self, ch_1, ch_2, ch_int, ch_out, r_2=4, drop_rate=0.):
        super(HFF_block, self).__init__()
        # Channel attention: global pooling + shared bottleneck of 1x1 convs.
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Conv2d(ch_2, ch_2 // r_2, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(ch_2 // r_2, ch_2, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
        # Spatial attention: 7x7 conv over the stacked [max, mean] maps.
        self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)
        # 1x1 projections of both branches to the intermediate width.
        self.W_l = Conv(ch_1, ch_int, 1, bn=True, relu=False)
        self.W_g = Conv(ch_2, ch_int, 1, bn=True, relu=False)
        # Previous-stage feature: lift channels, then halve the resolution.
        self.Avg = nn.AvgPool2d(2, stride=2)
        self.Updim = Conv(ch_int // 2, ch_int, 1, bn=True, relu=True)
        self.norm1 = LayerNorm(ch_int * 3, eps=1e-6, data_format="channels_first")
        self.norm2 = LayerNorm(ch_int * 2, eps=1e-6, data_format="channels_first")
        self.norm3 = LayerNorm(ch_1 + ch_2 + ch_int, eps=1e-6, data_format="channels_first")
        self.W3 = Conv(ch_int * 3, ch_int, 1, bn=True, relu=False)
        self.W = Conv(ch_int * 2, ch_int, 1, bn=True, relu=False)
        self.gelu = nn.GELU()
        self.residual = IRMLP(ch_1 + ch_2 + ch_int, ch_out)
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward(self, l, g, f=None):
        """Fuse the local, global and (optional) previous-stage features.

        Args:
            l: Local feature map with ``ch_1`` channels.
            g: Global feature map with ``ch_2`` channels and the same
                spatial size as ``l`` (they are concatenated on dim 1).
            f: Fused feature of the previous stage, or ``None`` for the
                first stage. It is average-pooled by 2 before being
                concatenated with the projections of ``l`` and ``g``.

        Returns:
            The fused feature map produced by the IRMLP, plus the shortcut
            when ``f`` is given.
        """
        W_local = self.W_l(l)   # local feature from Local Feature Block
        W_global = self.W_g(g)  # global feature from Global Feature Block
        if f is not None:
            W_f = self.Updim(f)
            W_f = self.Avg(W_f)
            shortcut = W_f
            X_f = torch.cat([W_f, W_local, W_global], 1)
            X_f = self.norm1(X_f)
            X_f = self.W3(X_f)
            X_f = self.gelu(X_f)
        else:
            # No previous stage: nothing to add back after the IRMLP.
            shortcut = 0
            X_f = torch.cat([W_local, W_global], 1)
            X_f = self.norm2(X_f)
            X_f = self.W(X_f)
            X_f = self.gelu(X_f)

        # spatial attention for ConvNeXt branch
        l_jump = l
        max_result, _ = torch.max(l, dim=1, keepdim=True)
        avg_result = torch.mean(l, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], 1)
        l = self.spatial(result)
        l = self.sigmoid(l) * l_jump

        # channel attention for transformer branch
        g_jump = g
        max_result = self.maxpool(g)
        avg_result = self.avgpool(g)
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        g = self.sigmoid(max_out + avg_out) * g_jump

        fuse = torch.cat([g, l, X_f], 1)
        fuse = self.norm3(fuse)
        fuse = self.residual(fuse)
        fuse = shortcut + self.drop_path(fuse)
        return fuse
class Conv(nn.Module):
    """2D convolution with "same"-style padding and optional BN / ReLU.

    Args:
        inp_dim: Number of input channels; checked on every forward call.
        out_dim: Number of output channels.
        kernel_size: Square kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
        bn: Append ``BatchNorm2d(out_dim)`` after the convolution.
        relu: Append an in-place ``ReLU`` as the last step.
        bias: Whether the convolution has a bias term.
        group: Number of convolution groups (``groups`` of ``nn.Conv2d``).
            ``group == inp_dim == out_dim`` gives a depthwise convolution.

    Note:
        ``group`` used to be accepted but silently ignored, so every layer
        was a dense convolution. It is now forwarded to ``nn.Conv2d``, which
        changes the weight shape of layers built with ``group != 1``;
        checkpoints saved with the old behavior will not load into them.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True, group=1):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(
            inp_dim,
            out_dim,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2,
            groups=group,
            bias=bias,
        )
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU(inplace=True)
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        # Fail early with both channel counts in the message on a mismatch.
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
#### Inverted Residual MLP
class IRMLP(nn.Module):
    """Inverted residual MLP built from ``Conv`` layers.

    Pipeline: 3x3 ``Conv`` (created with ``group=inp_dim``) -> GELU ->
    add the input -> BatchNorm -> 1x1 ``Conv`` expanding to four times the
    width -> GELU -> 1x1 ``Conv`` with BatchNorm projecting to ``out_dim``.

    Args:
        inp_dim: Number of input channels.
        out_dim: Number of output channels.
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        expanded_dim = inp_dim * 4
        self.conv1 = Conv(inp_dim, inp_dim, kernel_size=3, relu=False, bias=False, group=inp_dim)
        self.conv2 = Conv(inp_dim, expanded_dim, kernel_size=1, relu=False, bias=False)
        self.conv3 = Conv(expanded_dim, out_dim, kernel_size=1, relu=False, bias=False, bn=True)
        self.gelu = nn.GELU()
        self.bn1 = nn.BatchNorm2d(inp_dim)

    def forward(self, x):
        # Spatial mixing stage with an identity skip around it.
        mixed = self.gelu(self.conv1(x))
        mixed += x
        # Normalize, expand the channels, then project to the output width.
        expanded = self.gelu(self.conv2(self.bn1(mixed)))
        return self.conv3(expanded)
if __name__ == '__main__':
    # Smoke test: run one forward pass and print the output shape.
    # Use the GPU when one is present, otherwise fall back to the CPU so
    # the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(4, 64, 128, 128).to(device)
    y = torch.randn(4, 64, 128, 128).to(device)
    # z is optional: it is the fused feature of a previous stage, with half
    # the intermediate channels and twice the spatial resolution of x / y.
    z = torch.randn(4, 32, 256, 256).to(device)
    model = HFF_block(64, 64, 64, 64).to(device)
    out = model(x, y, z)
    print(out.shape)