Image Classification in PyTorch with an Improved FasterNet

1. FasterNet Overview

FasterNet argues that partial convolution is fast and effective: it applies the filter to only a subset of the input channels and leaves the rest untouched. The proposed PConv reduces computational redundancy and memory access, achieving lower FLOPs than regular convolution and higher FLOPS (throughput) than depthwise/group convolution.

FasterNet's architecture is shown below. Each hierarchical stage contains a stack of FasterNet blocks, with an embedding layer at the front and a merging layer after each stage. To fully and efficiently use the information from all channels, the authors append a PWConv (pointwise convolution) to the PConv. This combination pays more attention to the center position, and it resembles a T-shaped convolution while costing fewer FLOPs than an actual T-shaped convolution.

Comparison between the paper's PConv + PWConv and a T-shaped convolution:
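The PConv described above can be sketched in a few lines. This is a minimal re-implementation for illustration (the class name and the 1/4 split ratio follow the paper's description; the official code calls this module `Partial_conv3`):

```python
import torch
import torch.nn as nn

class PartialConv(nn.Module):
    """Minimal PConv sketch: a 3x3 conv is applied to only the first
    1/n_div of the channels; the remaining channels pass through untouched."""
    def __init__(self, dim, n_div=4):
        super().__init__()
        self.dim_conv = dim // n_div               # channels that get convolved
        self.dim_untouched = dim - self.dim_conv   # channels left as-is
        self.conv = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)

    def forward(self, x):
        x1, x2 = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
        x1 = self.conv(x1)
        return torch.cat((x1, x2), dim=1)

x = torch.randn(1, 96, 56, 56)
print(PartialConv(96)(x).shape)  # torch.Size([1, 96, 56, 56])
```

With n_div=4 only 96/4 = 24 channels are convolved, which is exactly the `Conv2d(24, 24, ...)` inside `Partial_conv3` in the parameter table below.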

2. Training Results

Training procedure: to make the effect easy to see, no data augmentation was used. Data preprocessing is as follows:

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

A simple optimizer with a cosine annealing schedule:

import torch.optim as optim

optimizer = optim.AdamW(model_ft.parameters(), lr=model_lr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=30, eta_min=1e-7)

Training hyperparameters:

model_lr = 1e-4
BATCH_SIZE = 32
EPOCHS = 100
classes = 3
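Putting the optimizer, scheduler, and hyperparameters above together, one training step might look like the sketch below. The model and data here are stand-ins so the sketch is self-contained; in the actual run they are the FasterNet model and a DataLoader over the 3-class dataset:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model_lr = 1e-4

# Stand-in model and one stand-in batch; replace with FasterNet and a DataLoader.
model_ft = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model_ft.parameters(), lr=model_lr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-7)

for epoch in range(2):  # EPOCHS (100) in the real run
    inputs, labels = torch.randn(32, 10), torch.randint(0, 3, (32,))
    optimizer.zero_grad()
    loss = criterion(model_ft(inputs), labels)
    loss.backward()
    optimizer.step()
    cosine_schedule.step()  # advance the cosine schedule once per epoch
```

Note that `CosineAnnealingLR` is stepped once per epoch, so the learning rate decays from 1e-4 toward 1e-7 over each 30-epoch cosine cycle.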

Results:

The best training accuracy reached 99.16%, and the best validation accuracy reached 98.91%.

Next, compute the model's performance metrics: Params, FLOPs, and throughput.

import torch
from torchsummary import summary

import models.fasternet as fasternet

# Instantiate the model; the constructor name depends on what
# models/fasternet.py exposes in your project
model = fasternet.FasterNet()
model = model.cuda()
print(summary(model, (3, 224, 224)))

Running the code above gives the parameter count:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 96, 56, 56]           4,608
       BatchNorm2d-2           [-1, 96, 56, 56]             192
        PatchEmbed-3           [-1, 96, 56, 56]               0
            Conv2d-4           [-1, 24, 56, 56]           5,184
     Partial_conv3-5           [-1, 96, 56, 56]               0
            Conv2d-6          [-1, 192, 56, 56]          18,432
       BatchNorm2d-7          [-1, 192, 56, 56]             384
              ReLU-8          [-1, 192, 56, 56]               0
            Conv2d-9           [-1, 96, 56, 56]          18,432
         Identity-10           [-1, 96, 56, 56]               0
         MLPBlock-11           [-1, 96, 56, 56]               0
       BasicStage-12           [-1, 96, 56, 56]               0
           Conv2d-13          [-1, 192, 28, 28]          73,728
      BatchNorm2d-14          [-1, 192, 28, 28]             384
     PatchMerging-15          [-1, 192, 28, 28]               0
           Conv2d-16           [-1, 48, 28, 28]          20,736
    Partial_conv3-17          [-1, 192, 28, 28]               0
           Conv2d-18          [-1, 384, 28, 28]          73,728
      BatchNorm2d-19          [-1, 384, 28, 28]             768
             ReLU-20          [-1, 384, 28, 28]               0
           Conv2d-21          [-1, 192, 28, 28]          73,728
         DropPath-22          [-1, 192, 28, 28]               0
         MLPBlock-23          [-1, 192, 28, 28]               0
           Conv2d-24           [-1, 48, 28, 28]          20,736
    Partial_conv3-25          [-1, 192, 28, 28]               0
           Conv2d-26          [-1, 384, 28, 28]          73,728
      BatchNorm2d-27          [-1, 384, 28, 28]             768
             ReLU-28          [-1, 384, 28, 28]               0
           Conv2d-29          [-1, 192, 28, 28]          73,728
         DropPath-30          [-1, 192, 28, 28]               0
         MLPBlock-31          [-1, 192, 28, 28]               0
       BasicStage-32          [-1, 192, 28, 28]               0
           Conv2d-33          [-1, 384, 14, 14]         294,912
      BatchNorm2d-34          [-1, 384, 14, 14]             768
     PatchMerging-35          [-1, 384, 14, 14]               0
           Conv2d-36           [-1, 96, 14, 14]          82,944
    Partial_conv3-37          [-1, 384, 14, 14]               0
           Conv2d-38          [-1, 768, 14, 14]         294,912
      BatchNorm2d-39          [-1, 768, 14, 14]           1,536
             ReLU-40          [-1, 768, 14, 14]               0
           Conv2d-41          [-1, 384, 14, 14]         294,912
         DropPath-42          [-1, 384, 14, 14]               0
         MLPBlock-43          [-1, 384, 14, 14]               0
           Conv2d-44           [-1, 96, 14, 14]          82,944
    Partial_conv3-45          [-1, 384, 14, 14]               0
           Conv2d-46          [-1, 768, 14, 14]         294,912
      BatchNorm2d-47          [-1, 768, 14, 14]           1,536
             ReLU-48          [-1, 768, 14, 14]               0
           Conv2d-49          [-1, 384, 14, 14]         294,912
         DropPath-50          [-1, 384, 14, 14]               0
         MLPBlock-51          [-1, 384, 14, 14]               0
           Conv2d-52           [-1, 96, 14, 14]          82,944
    Partial_conv3-53          [-1, 384, 14, 14]               0
           Conv2d-54          [-1, 768, 14, 14]         294,912
      BatchNorm2d-55          [-1, 768, 14, 14]           1,536
             ReLU-56          [-1, 768, 14, 14]               0
           Conv2d-57          [-1, 384, 14, 14]         294,912
         DropPath-58          [-1, 384, 14, 14]               0
         MLPBlock-59          [-1, 384, 14, 14]               0
           Conv2d-60           [-1, 96, 14, 14]          82,944
    Partial_conv3-61          [-1, 384, 14, 14]               0
           Conv2d-62          [-1, 768, 14, 14]         294,912
      BatchNorm2d-63          [-1, 768, 14, 14]           1,536
             ReLU-64          [-1, 768, 14, 14]               0
           Conv2d-65          [-1, 384, 14, 14]         294,912
         DropPath-66          [-1, 384, 14, 14]               0
         MLPBlock-67          [-1, 384, 14, 14]               0
           Conv2d-68           [-1, 96, 14, 14]          82,944
    Partial_conv3-69          [-1, 384, 14, 14]               0
           Conv2d-70          [-1, 768, 14, 14]         294,912
      BatchNorm2d-71          [-1, 768, 14, 14]           1,536
             ReLU-72          [-1, 768, 14, 14]               0
           Conv2d-73          [-1, 384, 14, 14]         294,912
         DropPath-74          [-1, 384, 14, 14]               0
         MLPBlock-75          [-1, 384, 14, 14]               0
           Conv2d-76           [-1, 96, 14, 14]          82,944
    Partial_conv3-77          [-1, 384, 14, 14]               0
           Conv2d-78          [-1, 768, 14, 14]         294,912
      BatchNorm2d-79          [-1, 768, 14, 14]           1,536
             ReLU-80          [-1, 768, 14, 14]               0
           Conv2d-81          [-1, 384, 14, 14]         294,912
         DropPath-82          [-1, 384, 14, 14]               0
         MLPBlock-83          [-1, 384, 14, 14]               0
           Conv2d-84           [-1, 96, 14, 14]          82,944
    Partial_conv3-85          [-1, 384, 14, 14]               0
           Conv2d-86          [-1, 768, 14, 14]         294,912
      BatchNorm2d-87          [-1, 768, 14, 14]           1,536
             ReLU-88          [-1, 768, 14, 14]               0
           Conv2d-89          [-1, 384, 14, 14]         294,912
         DropPath-90          [-1, 384, 14, 14]               0
         MLPBlock-91          [-1, 384, 14, 14]               0
           Conv2d-92           [-1, 96, 14, 14]          82,944
    Partial_conv3-93          [-1, 384, 14, 14]               0
           Conv2d-94          [-1, 768, 14, 14]         294,912
      BatchNorm2d-95          [-1, 768, 14, 14]           1,536
             ReLU-96          [-1, 768, 14, 14]               0
           Conv2d-97          [-1, 384, 14, 14]         294,912
         DropPath-98          [-1, 384, 14, 14]               0
         MLPBlock-99          [-1, 384, 14, 14]               0
      BasicStage-100          [-1, 384, 14, 14]               0
          Conv2d-101            [-1, 768, 7, 7]       1,179,648
     BatchNorm2d-102            [-1, 768, 7, 7]           1,536
    PatchMerging-103            [-1, 768, 7, 7]               0
          Conv2d-104            [-1, 192, 7, 7]         331,776
   Partial_conv3-105            [-1, 768, 7, 7]               0
          Conv2d-106           [-1, 1536, 7, 7]       1,179,648
     BatchNorm2d-107           [-1, 1536, 7, 7]           3,072
            ReLU-108           [-1, 1536, 7, 7]               0
          Conv2d-109            [-1, 768, 7, 7]       1,179,648
        DropPath-110            [-1, 768, 7, 7]               0
        MLPBlock-111            [-1, 768, 7, 7]               0
          Conv2d-112            [-1, 192, 7, 7]         331,776
   Partial_conv3-113            [-1, 768, 7, 7]               0
          Conv2d-114           [-1, 1536, 7, 7]       1,179,648
     BatchNorm2d-115           [-1, 1536, 7, 7]           3,072
            ReLU-116           [-1, 1536, 7, 7]               0
          Conv2d-117            [-1, 768, 7, 7]       1,179,648
        DropPath-118            [-1, 768, 7, 7]               0
        MLPBlock-119            [-1, 768, 7, 7]               0
      BasicStage-120            [-1, 768, 7, 7]               0
AdaptiveAvgPool2d-121            [-1, 768, 1, 1]               0
          Conv2d-122           [-1, 1280, 1, 1]         983,040
            ReLU-123           [-1, 1280, 1, 1]               0
          Linear-124                 [-1, 1000]       1,281,000
================================================================
Total params: 14,982,888
Trainable params: 14,982,888
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 117.32
Params size (MB): 57.16
Estimated Total Size (MB): 175.05
import torch
from thop import profile

import models.fasternet as fasternet

# Instantiate the model (constructor name depends on the project's code)
model = fasternet.FasterNet()
# Create a dummy input tensor
input_tensor = torch.randn(1, 3, 224, 224)
# Use thop's profile() to count the model's FLOPs
flops, params = profile(model, inputs=(input_tensor,))
# Print the result
print(f"FLOPs: {flops}")

Running the code above gives 1,916,867,072 FLOPs (about 1.92 G).

# Measure throughput over 100 forward passes
import time
import torch
import models.fasternet as fasternet

model = fasternet.FasterNet()
model.eval()
# Create a dummy input tensor
input_tensor = torch.randn(1, 3, 224, 224)
# Number of forward passes to time
num_iterations = 100

# Start timing
start_time = time.time()
with torch.no_grad():  # inference only; skip autograd bookkeeping
    for _ in range(num_iterations):
        output = model(input_tensor)
# Stop timing
end_time = time.time()

# Total elapsed time and throughput
total_time = end_time - start_time
throughput = num_iterations / total_time
print(f"Throughput: {throughput} ops/s")

The code above reports a throughput of about 22.52 ops/s.
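Note that the measurement above runs with batch size 1 and includes one-time setup costs in the first iteration. A fairer measurement warms the model up first and, on GPU, synchronizes before reading the clock, since CUDA kernels launch asynchronously. A sketch (using a small stand-in model so it runs anywhere; substitute the FasterNet model in practice):

```python
import time
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, 3, padding=1)  # stand-in; use the FasterNet model in practice
model.eval()
x = torch.randn(8, 3, 32, 32)  # a batch, so throughput can be reported per image
num_iterations = 10

with torch.no_grad():
    for _ in range(3):            # warmup: exclude one-time setup costs
        model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # GPU kernels are async; sync before timing
    start = time.perf_counter()
    for _ in range(num_iterations):
        model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

throughput = num_iterations * x.shape[0] / elapsed  # images per second
print(f"Throughput: {throughput:.1f} images/s")
```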

3. Improving FasterNet

1. Adjust the kernel size and stride of the Patch Embedding layer and introduce Efficient Channel Attention (ECA). The current kernel size is (4, 4) with stride (4, 4), which may discard fine detail when the input resolution is high. I reduce the kernel size and stride to increase the resolution of the output feature map and so preserve more image detail. I also add an ECA block to the Patch Embedding layer to better capture detail and the relationships between channels.

self.patch_embed = PatchEmbed(
    proj=nn.Sequential(
        nn.Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
        nn.BatchNorm2d(96),
        nn.ReLU(inplace=True),
        eca_block(96)  # add ECA attention
    ),
    norm=self.norm
)
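A quick check of how this kernel/stride change affects the embedding resolution. Switching from a 4x4/stride-4 to a 3x3/stride-2 (padding 1) convolution doubles the feature map's side length, which preserves more detail but also raises the FLOPs of every downstream stage:

```python
import torch
import torch.nn as nn

old = nn.Conv2d(3, 96, kernel_size=4, stride=4, bias=False)
new = nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=1, bias=False)
x = torch.randn(1, 3, 224, 224)
print(old(x).shape)  # torch.Size([1, 96, 56, 56])
print(new(x).shape)  # torch.Size([1, 96, 112, 112])
```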

The ECA implementation:

import math
import torch.nn as nn

class eca_block(nn.Module):
    # in_channel is the number of input feature-map channels;
    # b and gama are the two coefficients from the ECA formula
    def __init__(self, in_channel, b=1, gama=2):
        super(eca_block, self).__init__()

        # Adapt the 1D kernel size to the number of input channels
        kernel_size = int(abs((math.log(in_channel, 2) + b) / gama))
        # Force the kernel size to be odd so the conv can be centered
        if kernel_size % 2 == 0:
            kernel_size = kernel_size + 1

        # Zero-padding needed to keep the sequence length unchanged
        padding = kernel_size // 2

        # Global average pooling: output feature map is 1x1
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        # 1D conv with one input and one output channel; kernel size is adaptive
        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                              bias=False, padding=padding)
        # Sigmoid normalizes the channel weights to (0, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # Shape of the input feature map
        b, c, h, w = inputs.shape

        # Global average pooling: [b,c,h,w] ==> [b,c,1,1]
        x = self.avg_pool(inputs)
        # Reshape into a sequence: [b,c,1,1] ==> [b,1,c]
        x = x.view([b, 1, c])
        # 1D convolution: [b,1,c] ==> [b,1,c]
        x = self.conv(x)
        # Normalize the weights
        x = self.sigmoid(x)
        # Reshape back: [b,1,c] ==> [b,c,1,1]
        x = x.view([b, c, 1, 1])

        # Scale the input by the channel weights: [b,c,h,w]*[b,c,1,1] ==> [b,c,h,w]
        outputs = x * inputs
        return outputs
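The adaptive kernel-size rule above can be checked on its own. The helper below re-states just that formula (the function name is mine, for illustration):

```python
import math

def eca_kernel_size(channels, b=1, gamma=2):
    """ECA's adaptive 1D kernel size: k = |(log2(C) + b) / gamma|, forced odd."""
    k = int(abs((math.log(channels, 2) + b) / gamma))
    return k if k % 2 else k + 1

for c in (48, 96, 384, 768):
    print(c, eca_kernel_size(c))  # 48->3, 96->3, 384->5, 768->5
```

The value 3 for 48 channels matches the `Conv1d(1, 1, kernel_size=(3,))` inside the modified model printed below.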

2. Since my task has only a few classes, I also reduce the network's depth and width, and add SE (Squeeze-and-Excitation) or CBAM (Convolutional Block Attention Module) attention to the MLPBlock inside BasicStage. The concrete changes:

Original network (before modification, excerpt):

(stages): Sequential(
    (0): BasicStage(
      (blocks): Sequential(
        (0): MLPBlock(
          (drop_path): Identity()
          (mlp): Sequential(
            (0): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (spatial_mixing): Partial_conv3(
            (partial_conv3): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
      )
    )
    (1): PatchMerging(
      (reduction): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2), bias=False)
      (norm): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicStage(
      (blocks): Sequential(
        (0): MLPBlock(
          (drop_path): DropPath(drop_prob=0.008)
          (mlp): Sequential(
            (0): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (spatial_mixing): Partial_conv3(
            (partial_conv3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
        (1): MLPBlock(
          (drop_path): DropPath(drop_prob=0.017)
          (mlp): Sequential(
            (0): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (spatial_mixing): Partial_conv3(
            (partial_conv3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
      )
    )
    (3): PatchMerging(
      (reduction): Conv2d(192, 384, kernel_size=(2, 2), stride=(2, 2), bias=False)
      (norm): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )

After modification:

FasterNet(
  (patch_embed): PatchEmbed(
    (proj): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(4, 4), bias=False)
      (1): eca_block(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (conv): Conv1d(1, 1, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (sigmoid): Sigmoid()
      )
    )
    (norm): Sequential(
      (0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stages): Sequential(
    (0): Sequential(
      (0): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=48, out_features=3, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=3, out_features=48, bias=True)
          (3): Sigmoid()
        )
      )
    )
    (1): Sequential(
      (0): Conv2d(48, 96, kernel_size=(2, 2), stride=(2, 2), bias=False)
      (1): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2), bias=False)
      (2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicStage(
      (blocks): Sequential(
        (0): MLPBlock(
          (drop_path): DropPath(drop_prob=0.008)
          (mlp): Sequential(
            (0): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (spatial_mixing): Partial_conv3(
            (partial_conv3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
        (1): MLPBlock(
          (drop_path): DropPath(drop_prob=0.017)
          (mlp): Sequential(
            (0): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (spatial_mixing): Partial_conv3(
            (partial_conv3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
      )
    )

As shown, I reduced the network width in several BasicStages and introduced SE into the MLPBlock, improving representational power while reducing computational cost. Retraining with the same hyperparameters gives:

The best training accuracy reached 99.70% and the best validation accuracy 98.97%; both convergence and training speed improved.

Metrics: Params: 13,709,817; FLOPs: 538,492,176 (about 0.54 G); throughput: 50.35 ops/s. Params and FLOPs dropped significantly while throughput more than doubled, so model complexity is lower and training is faster; meanwhile, the added attention mechanisms improve the model's ability to capture features. The improvement succeeded, and further pruning could make the model lighter still.

4. Summary

By narrowing the feature-extraction layers and adding attention mechanisms, the model's classification ability improved and it became better suited to my specific task. Data augmentation and model pruning could optimize it further; feel free to try it yourself.
