Pytorch编写一个融合多尺度和Self-Attention的分类网络

步骤:

1. 设计多尺度模块,使用不同卷积核或扩张率提取特征。

2. 设计自注意力模块,计算特征图的空间注意力。

3. 将多尺度模块和自注意力模块整合到主干网络中。

4. 构建完整的分类模型,包括预处理、主干、多尺度+自注意力、分类头。

5. 测试模型各部分是否正确,输出形状是否符合预期。

6. 提供使用示例,包括数据加载和训练步骤。

以下是一个结合多尺度特征融合和Self-Attention机制的PyTorch分类网络实现,包含详细的模块设计和完整网络架构:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleBlock(nn.Module):
    """多尺度特征提取模块"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        # 分支1:3x3卷积
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels//4, 3, padding=1),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU(inplace=True)
        )
        
        # 分支2:5x5卷积
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels//4, 5, padding=2),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU(inplace=True)
        )
        
        # 分支3:空洞卷积(扩张率=2)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels//4, 3, 
                      padding=2, dilation=2),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU(inplace=True)
        )
        
        # 分支4:全局上下文
        self.branch4 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels//4, 1),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU(inplace=True)
        )
        
        # 特征融合后的卷积
        self.fusion = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = F.interpolate(self.branch4(x), size=x.shape[2:], 
                          mode='bilinear', align_corners=False)
        
        # 通道维度拼接
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.fusion(out)

class SelfAttention(nn.Module):
    """自注意力机制模块"""
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels//8, 1)
        self.key = nn.Conv2d(in_channels, in_channels//8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # 可学习权重系数

    def forward(self, x):
        batch_size, C, H, W = x.size()
        
        # 计算Q, K, V
        Q = self.query(x).view(batch_size, -1, H*W).permute(0, 2, 1)  # [B, N, C']
        K = self.key(x).view(batch_size, -1, H*W)  # [B, C', N]
        V = self.value(x).view(batch_size, -1, H*W)  # [B, C, N]

        # 计算注意力图
        attention = torch.bmm(Q, K)  # [B, N, N]
        attention = F.softmax(attention, dim=-1)

        # 特征聚合
        out = torch.bmm(V, attention.permute(0, 2, 1))  # [B, C, N]
        out = out.view(batch_size, C, H, W)
        
        return self.gamma * out + x  # 残差连接

class MSANet(nn.Module):
    """多尺度+自注意力分类网络"""
    def __init__(self, num_classes=1000):
        super().__init__()
        
        # 特征提取主干
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        
        # 网络阶段
        self.stage1 = self._make_stage(64, 128, num_blocks=2)
        self.stage2 = self._make_stage(128, 256, num_blocks=3)
        self.stage3 = self._make_stage(256, 512, num_blocks=3)
        
        # 分类头
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def _make_stage(self, in_channels, out_channels, num_blocks):
        """构建网络阶段"""
        layers = []
        # 下采样模块
        layers.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        # 添加多个多尺度+注意力模块
        for _ in range(num_blocks):
            layers.append(MultiScaleBlock(out_channels, out_channels))
            layers.append(SelfAttention(out_channels))
        
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)       # 初始特征提取
        x = self.stage1(x)     # 阶段1
        x = self.stage2(x)    # 阶段2 
        x = self.stage3(x)     # 阶段3
        return self.head(x)    # 分类输出

if __name__ == "__main__":
    # 验证网络结构
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MSANet(num_classes=1000).to(device)
    
    # 打印参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params/1e6:.2f}M")
    
    # 测试前向传播
    dummy_input = torch.randn(2, 3, 224, 224).to(device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")  # 应为 torch.Size([2, 1000])

网络结构说明

1)多尺度模块 (MultiScaleBlock)

并行四个分支:

3x3常规卷积 5x5大核卷积 3x3空洞卷积(扩张率2)

全局平均池化+1x1卷积 特征融合后使用1x1卷积调整通道数

2)自注意力模块 (SelfAttention)

使用QKV机制计算空间注意力

包含可学习的残差权重系数

计算复杂度通过通道压缩优化

3)整体网络架构 (MSANet) Stem:

初始特征提取(7x7卷积 + MaxPool)

3个Stage:每个Stage包含下采样和多尺度-注意力模块堆叠

分类头:全局平均池化 + 全连接层

主要特点

1)多尺度特征融合

同时捕获局部细节和全局上下文

不同感受野的特征互补

空洞卷积扩大感受野不增加参数量

2)注意力机制

增强重要特征区域的响应

建立长距离依赖关系

残差设计保证训练稳定性

3)高效结构设计

通道数渐进增加

下采样集中在特定阶段

参数量约25M(可根据需求调整)

使用示例

# 初始化模型
model = MSANet(num_classes=10)

# 训练配置
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 数据加载
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, shuffle=True)

# 训练循环
for epoch in range(100):
    for inputs, labels in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值