VMamba模型

摘要

本周阅读了 VMamba: Visual State Space ModelVMamba 这篇文献,VMamba是一种通用视觉主干,具有基于 SSM 的块,用于高效的视觉表示学习。 VMamba 在降低注意力计算复杂性方面的有效性很大程度上归功于 S6 模型中存在的选择性扫描机制,也称为选择性 SSM。与允许在上下文中进行密集信息路由的传统注意力计算方法不同,S6 要求一维数组(例如文本序列)中的每个元素仅通过压缩隐藏状态来获取上下文知识,从而将二次复杂度降低为线性复杂度,同时实验结果证明了 VMamba 在各种视觉感知任务中的良好性能,凸显了与现有基准模型相比,其在输入缩放效率方面的显着优势。本文将详细介绍 VMamba

Abstract

This week read the paper VMamba: Visual State Space ModelVMamba, a generalized visual backbone with SSM-based blocks for efficient visual representation learning. VMamba’s effectiveness in reducing the complexity of attentional computation is largely attributed to the selective scanning mechanism, also known as selective SSM, present in the S6 model.Unlike traditional attentional computation methods that allow for dense information routing in context, S6 requires that each element in a one-dimensional array (e.g., a sequence of text) acquires contextual knowledge by compressing the hidden state only, thereby reducing the quadratic complexity to linear complexity, while experimental results demonstrate the good performance of VMamba in a variety of visual perception tasks, highlighting its significant advantage in input scaling efficiency over existing benchmark models. In this paper, we present a detailed description of VMamba

1. VMamba模型

文献出处:VMamba: Visual State Space Model

1.1 文献摘要

CNN和VIT一直以来都是视觉领域的骨干网络,虽然 ViT 最近因其卓越的拟合能力而比 CNN 获得了突出地位,但其可扩展性在很大程度上受到注意力计算的二次复杂度的限制。

作者在本文提出了 VMamba,目的是为了将计算复杂度降低到线性,同时保留 ViT 的优势特征,同时也引入了交叉扫描模块(CSM),以实现具有全局感受野的 2D 图像空间中的 1D 选择性扫描。

实验结果证明了 VMamba 在各种视觉感知任务中的良好性能,凸显了与现有基准模型相比,其在输入缩放效率方面的显着优势。

1.2 研究背景

最近,状态空间模型(SSM)在自然语言处理(NLP)任务中展示了具有线性复杂性的长序列建模的巨大潜力。

作者提出了 VMamba,这是一种通用视觉主干,具有基于 SSM 的块,用于高效的视觉表示学习。 VMamba 在降低注意力计算复杂性方面的有效性很大程度上归功于 S6 模型中存在的选择性扫描机制,也称为选择性 SSM。与允许在上下文中进行密集信息路由的传统注意力计算方法不同,S6 要求一维数组(例如文本序列)中的每个元素仅通过压缩隐藏状态来获取上下文知识,从而将二次复杂度降低为线性复杂度。

然而,由于视觉数据的二维性质,单个扫描过程很难同时捕获不同方向上的依赖性信息,从而导致感受野受到限制。 我们将此问题称为“方向敏感”问题,并建议通过新引入的交叉扫描模块(CSM)来解决它。 CSM 不是以单向模式(列向或行向)遍历图像特征图的空间域,而是采用四向扫描策略,即从左上角和右下角开始遍历整个特征 映射到相反的位置(如下图)。 该策略确保特征图中的每个元素集成来自不同方向的所有其他位置的信息,从而在不增加计算复杂度的情况下实现全局感受野。
在这里插入图片描述

1.3 状态空间模型(SSM)

SSM 可以被视为线性时不变 (LTI) 系统,它通过隐藏状态 h(t) ε CN 将输入刺激 u(t) ε RL 映射到输出响应 y(t) ε RL。 它们通常被表述为线性常微分方程 (ODE)
在这里插入图片描述
离散化 状态空间模型(SSM)作为连续时间模型,在集成到深度学习算法中时面临着巨大的挑战。为了克服这个障碍,离散化过程势在必行。
在这里插入图片描述
作者首先使用 CSM(扫描扩展)扫描图像。然后通过 S6 块单独处理四个结果特征,并将四个输出特征合并(扫描合并)以构建最终的 2D 特征图。

通过 SS2D 模块传递数据涉及三个步骤:交叉扫描、使用 S6 块进行选择性扫描以及交叉合并。 给定输入数据,SS2D 首先沿着四个不同的遍历路径(即交叉扫描)将图像块展开为序列,使用单独的 S6 块并行处理每个块序列,然后重塑并合并结果序列以形成输出图 (即交叉合并)。 通过采用互补的遍历路径,SS2D使图像中的每个像素能够有效地整合来自不同方向的所有其他像素的信息,从而促进全局感受野的建立。

1.4 VMamba架构

VMamba-Tiny 的架构概述如下图所示。 VMamba 首先使用 Stem 模块将输入图像划分为图块,从而生成空间维度为 H 4 × W 4 \frac{H}{4} \times \frac{W}{4} 4H×4W 的 2D 特征图。
在这里插入图片描述
随后,多个网络阶段,每个阶段由 VSS 块组成,前面是下采样层(第一阶段除外),用于创建分辨率为 H 8 × W 8 \frac{H}{8} \times \frac{W}{8} 8H×8W H 16 × W 16 \frac{H}{16} \times \frac{W}{16} 16H×16W H 32 × W 32 \frac{H}{32} \times \frac{W}{32} 32H×32W。 下采样操作是通过补丁合并进行的,VSS块的详细结构如下图所示:
在这里插入图片描述
普通 VSS 块的结构如下图所示,这两个块都可以看作具有跳跃连接的残差网络。 残差网络包含两个分支:一个用于使用 3 × 3 深度卷积层进行特征提取,另一个由线性映射和激活层组成,激活层计算乘性门控信号。 Mamba 和普通 VSS 模块之间的主要区别在于用 SS2D 模块替换了 S6 模块,这使得选择性扫描能够适应 2D 视觉数据。
在这里插入图片描述
尽管在长序列建模方面效率很高,但基于 SSM 的架构 [14] 在处理较小规模的输入时经常会遇到计算速度降低的情况,这可能会限制 VMamba 的实际用途。

如下图所示,普通 VMamba-Tiny 模型实现了 426 个图像/秒的吞吐量,包含 22.9M 个参数和 5.6G FLOP(如果选择性扫描操作可以实现,FLOP 将降至 4.5G) 由单个 for 循环实现)。 低吞吐量和高内存开销给VMamba的实际部署带来了挑战。 因此,为了提高其推理速度,人们付出了巨大的努力,主要集中在实现细节和架构设计方面的进步。
在这里插入图片描述
从VMamba V0到V2,我们先后在torch.autograd.Function中实现了CSM,然后在Triton中重新实现了它。 这些修改有助于将吞吐量从 426 增加到 467。然后,在 V3 中,我们调整了与选择性扫描操作相关的 CUDA 实现,以适应 float16 输入张量并生成具有 float32 数据类型的输出张量。 与处理 float32 数据类型张量的实现相比,此调整提高了性能,特别是在训练期间,同时与对输入和输出张量使用 float16 相比,还实现了更高的数值稳定性。 此外,在 V4 和 V5 中,我们用线性变换(即 torch.nn.function.linear)替代了选择性扫描中相对较慢的 einsum 操作。 我们还采用了(B,C,H,W)的张量布局来消除不必要的数据排列。 这些变化导致吞吐量增加了 49.5%(从 426 增加到 637),并且不影响其他指标,例如参数数量、FLOP 和 ImageNet-1K 上的分类性能。

1.5 实验

1.5.1 ImageNet-1K 上的图像分类

我们使用 ImageNet-1K 数据集评估 VMamba 在图像分类方面的性能。 遵循[31]中概述的评估协议,VMamba-T/S/B模型从头开始训练300个epoch,前20个epoch专门用于预热,批量大小为1024。训练过程使用AdamW 优化器[34],贝塔设置为(0.9,0.999),动量为0.9,余弦衰减学习率调度器,初始学习率为1×10−3,权重衰减为0.05。 还应用了标签平滑 (0.1) 和指数移动平均 (EMA) 等其他技术。 除此之外,没有采用进一步的培训技术。

下表总结了 VMamba 与 ImageNet-1K 上基准骨干模型的比较结果。很明显,在相似的 FLOP 下,VMamba-T 的性能达到 82.5%,超过 RegNetY-4G 2.5%,超过 DeiT-S 2.5%。 2.7%,Swin-T 1.2%。 值得注意的是,VMamba 的这些性能优势在小型和基本规模模型中始终存在。 具体来说,VMamba-S 的 top-1 准确率达到 83.6%,比 RegNetY-8G 提高 1.9%,比 Swin-S 提高 0.6%。 同时,VMamba-B 的 top-1 准确率达到 83.9%,超过 RegNetY-16G 1.0%,超过 DeiT-B 0.6%。 在计算效率方面,虽然现有的基于 SSM 的视觉模型通常仅在大规模输入 [68](例如 1024 × 1024)下才表现出明显更好的吞吐量,但 VMamba-T 即使在输入分辨率为 224 × 224。这种性能更好,或者至少与最先进的方法相当,并且这种优势在 VMamba-S 和 VMamba-B 中仍然存在。 值得注意的是,随着输入大小从 224 × 224 扩展到 1024 × 1024,VMamba 相对于现有方法的优势变得更加明显,如表 4 所示。后续章节将对此主题进行进一步讨论。
在这里插入图片描述

1.5.2 COCO 上的物体检测

我们使用 MSCOCO 2017 数据集评估 VMamba 在对象检测方面的性能。 我们的训练框架是使用 MMDetection 库构建的,并且我们遵循 Swin中使用的超参数和 Mask-RCNN 检测器。 具体来说,我们采用 AdamW 优化器并对 12 和 36 epoch 的预训练分类模型(在 ImageNet-1K 上)进行微调。 VMamba-T/S/B 的丢弃路径率分别设置为 0.2%/0.3%/0.5%。 学习率初始化为 1×10−4,并在第 9 和 11 epoch 减少 10×。 我们实现了批量大小为 16 的多尺度训练和随机翻转,这与目标检测评估的既定实践一致。

VMamba 在 COCO 上的框/掩模平均精度 (AP) 方面保持优势,无论采用何种训练计划(12 或 36 epoch)。 具体来说,通过 12 epoch 的微调计划,VMamba-T/S/B 模型实现了 47.4%/48.7%/49.2% 的目标检测 mAP,超过了 Swin-T/S/B 4.7%/3.9%/2.3 % mAP 和 ConvNeXt-T/S/B 分别提高 3.2%/3.3%/2.2% mAP。 在相同配置下,VMambaT/S/B 的实例分割 mIoU 为 42.7%/43.7%/43.9%,比 Swin-T/S/B 高出 3.4%/2.8%/1.6% mIoU,而 ConvNeXt-T/S/ B 分别为 2.6%/1.9%/1.3% mIoU。 此外,VMamba 在多尺度训练的 36 epoch 微调方案下仍然具有优势,如表 2 所示。与 Swin [32]、ConvNeXt [33]、PVTv2 [55] 和 ViT 等同行相比 [12](使用适配器),VMamba-T/S 表现出卓越的性能,在对象检测上分别实现了 48.9%/49.9% mAP,在实例分割上分别实现了 43.7%/44.2% mIoU。 这些结果强调了 VMamba 在具有密集预测的下游任务中实现有希望的性能的潜力。
在这里插入图片描述

总结

本文介绍了 VMamba,这是一种多功能主干网络,专为使用状态空间模型 (SSM) 进行高效视觉表示学习而设计。 VMamba 的主要目标是将选择性 SSM 的优点(包括全局感受野、输入相关的加权参数和线性计算复杂性)融入视觉数据处理中。 具体来说,我们提出交叉扫描模块(CSM)来弥合一维选择性扫描和二维视觉数据之间的差距,并通过数学推导和定性可视化说明其与注意力机制的关系及其在实现全局感受野方面的有效性 。 此外,我们通过改进技术实现和架构设计,显着提高了 VMamba 的推理速度。 VMamba 系列(包括 VMamba-T/S/B 模型)的有效性已通过大量实验和消融研究得到证明,超越了流行的 CNN 和视觉 Transformer 的性能。 此外,VMamba 随着输入分辨率的提高而表现出卓越的可扩展性,在保持线性计算复杂性的同时表现出最小的性能下降。

下周我将具体通过pytorch实现这个网络架构,加油~

2. pytorch练习

数据集处理

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "CUB_200_2011")
    origin_CUB_path = os.path.join(data_root, "images")
    assert os.path.exists(origin_CUB_path), "path '{}' does not exist.".format(origin_CUB_path)

    CUB_class = [cla for cla in os.listdir(origin_CUB_path)
                    if os.path.isdir(os.path.join(origin_CUB_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in CUB_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in CUB_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in CUB_class:
        cla_path = os.path.join(origin_CUB_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

参数设置

import argparse

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def get_args():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                            help='models architecture: default: resnet18)') # arch是需要加载的预训练模型名
    parser.add_argument("--optimizer", default="SGD", type=str, help='["SGD", "Adam", "AdamW"]')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    # optimizer
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')

    # center loss
    parser.add_argument('--parts', default=32, type=int,
                        metavar='N', help='number of parts (default: 32)')
    parser.add_argument('--alpha', default=0.95, type=float,
                        metavar='N', help='weight for BAP loss')

    # scheduler
    parser.add_argument('--decay-step', default=20, type=int, metavar='N',
                        help='learning rate decay step')
    parser.add_argument('--gamma', default=0.5, type=float, metavar='M',
                        help='gamma')
    parser.add_argument('-p', '--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate models on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained models')
    # parser.add_argument('--world-size', default=-1, type=int,
    #                     help='number of nodes for distributed training')
    # parser.add_argument('--rank', default=-1, type=int,
    #                     help='node rank for distributed training')
    # parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
    #                     help='url used to set up distributed training')
    # parser.add_argument('--dist-backend', default='nccl', type=str,
    #                     help='distributed backend')
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=1, type=int,
                        help='GPU id to use.')
    # parser.add_argument('--multiprocessing-distributed', action='store_true',
    #                     help='Use multi-processing distributed training to launch '
    #                          'N processes per node, which has N GPUs. This is the '
    #                          'fastest way to use PyTorch for either single node or '
    #                          'multi node data parallel training')
    # parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # training
    parser.add_argument('--dataset', type=str, default='CUB',choices=['CUB','Cars','Aircraft'],
                        help='dataset for FGVC')
    parser.add_argument('--name', type=str, default='test_case')
    parser.add_argument('--lr_step', type=int, default=30)  # lr_step
    parser.add_argument('--resize-size', type=int, default=512, help='validation resize size')
    parser.add_argument('--crop-size', type=int, default=448, help='validation crop size')
    parser.add_argument('--VAL-CROP', type=str2bool, nargs='?', const=True, default=True,
                        help='Evaluation method'
                             'If True, Evaluate on 256x256 resized and center cropped 224x224 map'
                             'If False, Evaluate on directly 224x224 resized map')
    # CAM
    parser.add_argument('--cam-thr', type=float, default=0.2, help='cam threshold value(default=0.15)')

    # Random Erasing
    parser.add_argument('--p', default=0.5, type=float, help='Random Erasing probability')
    parser.add_argument('--sh', default=0.4, type=float, help='max erasing area')
    parser.add_argument('--r1', default=0.3, type=float, help='aspect of erasing area')


    args = parser.parse_args()
    return args


Res2Net模型


import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
__all__ = ['Res2Net', 'res2net50']


model_urls = {
    'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
    'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
    'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
    'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
    'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
    'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
}


class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        """ 构造函数
        参数:
            inplanes: 输入通道维度
            planes: 输出通道维度
            stride: 卷积步长。替代池化层。
            downsample: 当stride = 1时为None
            baseWidth: conv3x3的基本宽度
            scale: 尺度数量。
            type: 'normal': 正常设置。 'stage': 新阶段的第一个块。
        """
        super(Bottle2neck, self).__init__()

        # 计算卷积核的宽度
        width = int(math.floor(planes * (baseWidth / 64.0)))
        # 第一个1x1卷积层
        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)

        # 计算重复次数
        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1

        # 如果是新阶段的第一个块,则使用平均池化层进行下采样
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)

        # 定义重复的卷积层和BN层
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        # 创建了两个 nn.ModuleList 对象 self.convs 和 self.bns,用于存储多个卷积层和批量归一化层。
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        # 最后一个1x1卷积层
        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        # 激活函数
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x

        # 第一个1x1卷积层的计算
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # 将输出按照宽度进行分割
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            # 如果是第一个块或者是新阶段的第一个块,则直接取分割后的部分
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                # 否则,累加之前的部分
                sp = sp + spx[i]
            # 对部分进行卷积、BN和ReLU操作
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                # 将处理后的部分拼接起来
                out = torch.cat((out, sp), 1)
        # 如果尺度不为1且为正常设置,将最后一个部分拼接到一起
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        # 如果尺度不为1且为新阶段的第一个块,则对最后一个部分进行平均池化并拼接
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        # 最后一个1x1卷积层的计算
        out = self.conv3(out)
        out = self.bn3(out)

        # 如果存在下采样,则对输入进行下采样
        if self.downsample is not None:
            residual = self.downsample(x)

        # 残差连接并进行ReLU激活
        out += residual
        out = self.relu(out)

        return out


class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        # 初始化Res2Net模型
        self.inplanes = 64  # 设置输入通道数为64
        self.baseWidth = baseWidth
        self.scale = scale
        super(Res2Net, self).__init__()  # 调用父类的构造函数

        # 定义网络的第一层:7x7的卷积层,输入通道数为3,输出通道数为64,步长为2,填充为3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Batch Normalization层,对每个channel的数据进行标准化
        self.bn1 = nn.BatchNorm2d(64)
        # 激活函数ReLU
        self.relu = nn.ReLU(inplace=True)
        # 最大池化层,窗口大小为3x3,步长为2,填充为1
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 定义4个Res2Net的阶段(stage)
        self.layer1 = self._make_layer(block, 64, layers[0])  # 第一个阶段,输出通道数为64
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 第二个阶段,输出通道数为128,步长为2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 第三个阶段,输出通道数为256,步长为2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 第四个阶段,输出通道数为512,步长为2

        # 全局平均池化层,将每个通道的特征图变成一个数
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 全连接层,将512维的特征向量映射到num_classes维的向量,用于分类
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # 初始化网络参数
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 使用kaiming正态分布初始化卷积层参数
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                # 将Batch Normalization层的权重初始化为1,偏置初始化为0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        # 构建Res2Net的一个阶段(stage),包含多个block
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 如果输入输出通道数不一致,或者步长不为1,需要添加下采样层
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # 构建阶段的每个block
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                            stype='stage', baseWidth=self.baseWidth, scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 定义前向传播过程
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def res2net50(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 model.
    Res2Net-50 refers to the Res2Net-50_26w_4s.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net50_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net101_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s']))
    return model

def res2net50_26w_6s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s']))
    return model

def res2net50_26w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s']))
    return model

def res2net50_48w_2s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_48w_2s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s']))
    return model

def res2net50_14w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_14w_8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s']))
    return model



if __name__ == '__main__':
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2net50_48w_2s(pretrained=False)
    model = model.cuda(0)
    print(model(images).size())
    print(model)

训练代码

# coding:utf-8 允许中文注释
import numpy as np
import os

import torchvision

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from option import get_args
from model import resnet50
from util import AverageMeter, accuracy, save_checkpoint, load_model_checkpoint
from res2net import res2net50_48w_2s

def init_seeds(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if seed == 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


best_acc1 = 0.
def repeat_channels(x):
    # 这个函数将输入的 PIL 图像 x 复制到三个通道,模拟 RGB 图像
    return x.repeat(3, 1, 1)

def main():
    print("Start...")
    global best_acc1
    args = get_args()
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    init_seeds(seed=0) # set random seed

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))  # 训练所用的GUP ID

    # directory for save
    args.log_folder = os.path.join('log', 'res2net50_48w_2s')
    if not os.path.exists(args.log_folder):
        os.makedirs(args.log_folder)



    if args.dataset == "CUB" and args.arch == "resnet50":
        channels = 2048
        num_classes = 200
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/data/tgf/resnet/Data/')
        # train_dir = '/data/tgf/resnet/Data/trian'
        # valid_dir = '/data/tgf/resnet/Data/test'
    elif args.dataset == 'Cars' and args.arch == "resnet50":
        channels = 2048
        num_classes = 196
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/tgf/resnet/CUB_200_2011/dataset')
        # train_dir = '/learn_pytorch/resnet/Data/trian'
        # valid_dir = '/learn_pytorch/resnet/Data/test'
    elif args.dataset == "Aircraft" and args.arch == "resnet50":
        channels = 2048
        num_classes = 100
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/data/tgf/resnet/Data')
        # train_dir = '/learn_pytorch/resnet/Data/trian'
        # valid_dir = '/learn_pytorch/resnet/Data/test'
    else:
        raise Exception("No dataset named {}".format(args.dataset))

    # Model
    print("=> creating model '{}'".format(args.arch))
    print("num_classes ", num_classes)
    model = res2net50_48w_2s(pretrained=True)
    # model_weight_path = "./resnet50_pre.pth"
    # assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    # model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # change fc layer structure
    in_channel = model.fc.in_features
    model.fc = nn.Linear(in_channel, num_classes)
    model = model.cuda()

    cudnn.benchmark = True

    # Loading training/validation dataset
    train_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Lambda(repeat_channels),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),  # RandomCrop for train and CenterCrop for test
        transforms.ToTensor(),
        # transforms.Lambda(repeat_channels),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])



    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_transform)
    print("train_dataset为:",train_dataset)
    valid_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"), transform=test_transform)
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, pin_memory=True)
    # print("train_loader为:",train_loader)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size,
                              shuffle=False, num_workers=args.workers, pin_memory=True)

    print("using {} images for training, {} images for validation.".format(len(train_dataset), len(valid_dataset)))

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, nesterov=True, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.gamma)

    # optionally resume from a checkpoint
    if args.resume:
        model, optimizer = load_model_checkpoint(model, optimizer, args)

    def train(train_loader, model, criterion, optimizer, epoch, args):
        # AverageMeter for Performance
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Switch to train mode
        model.train()

        # lr = next(iter(optimizer.param_groups))['lr']
        train_bar = tqdm(train_loader)  # 训练集进度条
        for batch_idx, (inputs, targets) in enumerate(train_bar):
            idx = batch_idx
            inputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()
            # inputs, targets = Variable(inputs), Variable(targets)

            # compute output
            outputs = model(inputs)  # 前向传播
            loss = criterion(outputs, targets)

            # # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()  # !!
            optimizer.step()

            # print info
            description = "[Train:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f},". \
                format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)
            train_bar.set_description(desc=description)

        return top1.avg, losses.avg

    best_acc_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = next(iter(optimizer.param_groups))['lr']
        # ————————————————Train————————————————#
        train_acc1, train_losses = train(train_loader, model, criterion, optimizer, epoch, args)
        scheduler.step()  # 放到每个epoch训练完之后

        # tensorboard
        with SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/train'), comment='train') as writer:
            writer.add_scalar('Train/learning_rate', lr, epoch)
            writer.add_scalar('Train/train_acc1', train_acc1, epoch)
            writer.add_scalar('Train/train_loss', train_losses, epoch)
            writer.flush()
            writer.close()

        # ————————————————Test————————————————#
        val_acc1, val_losses = validate(valid_loader, model, criterion, epoch, args)  # Test!!!

        # tensorboard
        with SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/val'), comment='test') as writer:
            writer.add_scalar('Test/val_acc1', val_acc1, epoch)
            writer.add_scalar('Test/val_loss', val_losses, epoch)
            writer.flush()
            writer.close()

        is_best = val_acc1 > best_acc1  # True / False
        best_acc1 = max(val_acc1, best_acc1)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'arch': args.arch,
        #     'state_dict': model.state_dict(),
        #     'best_acc1': best_acc1,
        #     'optimizer': optimizer.state_dict(),
        #     # 'scheduler': scheduler.state_dict()
        # }, is_best, args.log_folder)

        if is_best:
            best_acc_epoch = epoch + 1
            savepath = "/data/tgf/resnet/log/resnet50_in_CUB/best.pth"
            torch.save(model, savepath)

        print("Until %d epochs, Best Acc@1 %.3f in the %d-th epoch" % (epoch + 1, best_acc1, best_acc_epoch))

    with open(os.path.join(args.log_folder, 'result.txt'), 'w') as file:
        file.write("best_acc1 {}".format(best_acc1))
    file.close()





def validate(val_loader, model, criterion, epoch, args):
    # AverageMeter for Performance
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        val_bar = tqdm(val_loader)
        for batch_idx, (inputs, targets) in enumerate(val_bar):
            idx = batch_idx
            inputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()

            # Compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

            # print info
            description = "[Valid:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f}, ". \
                format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)
            val_bar.set_description(desc=description)

    return top1.avg, losses.avg


if __name__ == '__main__':
    main()

实验结果
在这里插入图片描述

  • 19
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

@默然

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值