Steps:
1. Design a multi-scale module that extracts features with different kernel sizes and dilation rates.
2. Design a self-attention module that computes spatial attention over the feature map.
3. Integrate the multi-scale and self-attention modules into the backbone network.
4. Build the full classification model: preprocessing stem, backbone, multi-scale + self-attention blocks, and classification head.
5. Verify that each part works and that the output shapes match expectations.
6. Provide a usage example covering data loading and training.
Below is a PyTorch classification network that combines multi-scale feature fusion with a self-attention mechanism, including the detailed module designs and the complete architecture:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiScaleBlock(nn.Module):
    """Multi-scale feature extraction module."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Branch 1: 3x3 convolution
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, 3, padding=1),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True)
        )
        # Branch 2: 5x5 convolution
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, 5, padding=2),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True)
        )
        # Branch 3: dilated convolution (dilation rate = 2)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, 3,
                      padding=2, dilation=2),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True)
        )
        # Branch 4: global context
        self.branch4 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels // 4, 1),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True)
        )
        # 1x1 convolution applied after concatenating the branches
        self.fusion = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        # Upsample the pooled global-context branch back to the input resolution
        b4 = F.interpolate(self.branch4(x), size=x.shape[2:],
                           mode='bilinear', align_corners=False)
        # Concatenate along the channel dimension
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.fusion(out)
class SelfAttention(nn.Module):
    """Self-attention module (spatial attention over the H*W positions)."""
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Compute Q, K, V
        Q = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)  # [B, N, C']
        K = self.key(x).view(batch_size, -1, H * W)                     # [B, C', N]
        V = self.value(x).view(batch_size, -1, H * W)                   # [B, C, N]
        # Attention map over all spatial positions
        attention = torch.bmm(Q, K)               # [B, N, N]
        attention = F.softmax(attention, dim=-1)
        # Aggregate the values with the attention weights
        out = torch.bmm(V, attention.permute(0, 2, 1))  # [B, C, N]
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x  # residual connection
class MSANet(nn.Module):
    """Multi-scale + self-attention classification network."""
    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem: initial feature extraction
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        # Backbone stages
        self.stage1 = self._make_stage(64, 128, num_blocks=2)
        self.stage2 = self._make_stage(128, 256, num_blocks=3)
        self.stage3 = self._make_stage(256, 512, num_blocks=3)
        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def _make_stage(self, in_channels, out_channels, num_blocks):
        """Build one backbone stage."""
        layers = []
        # Downsampling block (stride-2 convolution)
        layers.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        # Stack num_blocks multi-scale + attention block pairs
        for _ in range(num_blocks):
            layers.append(MultiScaleBlock(out_channels, out_channels))
            layers.append(SelfAttention(out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)     # initial feature extraction
        x = self.stage1(x)   # stage 1
        x = self.stage2(x)   # stage 2
        x = self.stage3(x)   # stage 3
        return self.head(x)  # classification logits
if __name__ == "__main__":
    # Sanity-check the network structure
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MSANet(num_classes=1000).to(device)
    # Report the parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f}M")
    # Test the forward pass; batch of 2, since BatchNorm needs more than
    # one value per channel in training mode
    dummy_input = torch.randn(2, 3, 224, 224).to(device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")  # expected: torch.Size([2, 1000])
Network structure
1) Multi-scale module (MultiScaleBlock)
Four parallel branches:
Regular 3x3 convolution
Large-kernel 5x5 convolution
Dilated 3x3 convolution (dilation rate 2)
Global average pooling + 1x1 convolution
After channel-wise concatenation, a 1x1 convolution fuses the branches and adjusts the channel count
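A quick shape check (a sketch, assuming the MultiScaleBlock class above is in scope): every branch preserves the input's spatial size, so the concatenation is valid and the block can be inserted at any resolution.
import torch
block = MultiScaleBlock(in_channels=64, out_channels=128)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # expected: torch.Size([2, 128, 56, 56])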
2) Self-attention module (SelfAttention)
Computes spatial attention with a Q/K/V mechanism
Includes a learnable residual weight coefficient (gamma)
Compute cost is reduced by compressing the Q/K channels
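The attention map has one row and one column per spatial position, so its size grows quadratically with resolution. The rough, illustrative numbers below show why the module is applied to downsampled feature maps and why Q/K are compressed to in_channels // 8.
H, W, C = 14, 14, 256   # roughly the stage2 feature map for a 224x224 input
N = H * W               # 196 spatial positions
print(N * N)            # 38416 attention entries per image
print(C // 8)           # Q/K are projected down to 32 channels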
3) Overall architecture (MSANet)
Stem: initial feature extraction (7x7 convolution + MaxPool)
3 stages: each stage contains a downsampling block followed by stacked multi-scale + attention blocks
Classification head: global average pooling + fully connected layers
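The shape trace below (a sketch, assuming a 224x224 input and the MSANet class above) shows how each stage halves the spatial resolution through its stride-2 convolution while the channel count doubles.
import torch
model = MSANet(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
for name in ['stem', 'stage1', 'stage2', 'stage3']:
    x = getattr(model, name)(x)
    print(name, tuple(x.shape))
# stem (2, 64, 56, 56) -> stage1 (2, 128, 28, 28) -> stage2 (2, 256, 14, 14) -> stage3 (2, 512, 7, 7)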
Key characteristics
1) Multi-scale feature fusion
Captures local detail and global context simultaneously
Features with different receptive fields complement each other
Dilated convolution enlarges the receptive field without adding parameters (see the sketch below)
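To illustrate the last point, this sketch compares a plain 3x3 convolution with its dilated counterpart: the weight count is identical, but with dilation=2 the kernel covers a 5x5 area.
import torch.nn as nn
plain = nn.Conv2d(64, 64, kernel_size=3, padding=1)
dilated = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)
print(sum(p.numel() for p in plain.parameters()))    # 36928
print(sum(p.numel() for p in dilated.parameters()))  # 36928, identical parameter count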
2) Attention mechanism
Strengthens the response of important feature regions
Models long-range dependencies
The residual design (gamma initialized to zero) keeps training stable
3) Efficient structure design
Channel count increases gradually across stages
Downsampling is concentrated at the start of each stage
Roughly 15M parameters with num_classes=1000 (adjustable as needed; see the breakdown below)
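The per-component breakdown below (a sketch using the MSANet class above) makes it easy to verify the total and shows that most of the budget sits in stage3, which is where channel trimming pays off.
model = MSANet(num_classes=1000)
for name, module in model.named_children():
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name:6s}: {n_params / 1e6:.2f}M")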
Usage example
# Initialize the model (e.g. for a 10-class task)
model = MSANet(num_classes=10)
# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Data loading (`dataset` is your own torch.utils.data.Dataset)
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, shuffle=True)
# Training loop
for epoch in range(100):
    for inputs, labels in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
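A minimal evaluation sketch to pair with the training loop; it assumes a val_loader built the same way as train_loader above and reports top-1 accuracy.
model.eval()
correct = total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Validation accuracy: {correct / total:.4f}")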