ViG:图像分类领域前沿


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

图片分类任务方法概述

卷积神经网络(CNN)

视觉Transformer(ViT)

视觉图神经网络(ViG)

ViG模型

图片切成patch

模型架构

图像输入

图结构生成

网络模块

图处理

特征变换

多尺度处理

输出头

ViG代码

PatchEmbedding

模型主体架构设计

核心代码

演示效果

附件使用

安装相应依赖包

获取cifa10数据集

运行代码


   本文所有资源均可在该地址处获取。

图片分类任务方法概述

卷积神经网络(CNN)

发展背景: CNN的出现标志着深度学习在图像识别领域的重大突破。最早的CNN模型可以追溯到1998年的LeNet,而2012年的AlexNet模型在ImageNet竞赛中取得优异成绩,使得CNN成为图像分类任务的主流方法。

分类方法优点:

局部感知野: 通过卷积操作,CNN能够捕捉图像的局部特征,减少参数数量。
参数共享: 卷积核在整张图像上共享,提高了模型的泛化能力。
平移不变性: CNN具有平移不变性,能够识别图像中的物体,即使它们的位置发生变化。

视觉Transformer(ViT)

发展背景: ViT于2020年被提出,借鉴了自然语言处理领域的Transformer架构,将自注意力机制应用于图像分类任务。

分类方法优点:

自注意力机制: 能够捕捉图像中的长距离依赖关系,提高分类准确性。
可扩展性: Transformer结构易于扩展,适用于大规模数据集。
并行计算: 自注意力机制使得ViT能够更好地利用并行计算资源。

视觉图神经网络(ViG)

发展背景: ViG的提出是为了解决CNN和ViT在处理不规则和复杂目标时的局限性。ViG将图像视为图结构,通过图卷积操作进行特征提取和分类。

分类方法优点:

灵活的图结构: ViG采用图结构表示图像,能够更好地处理不规则形状的物体,提高对复杂场景的识别能力。
图卷积操作: 通过图卷积,ViG能够有效地聚合和更新节点信息,捕捉局部和全局特征。
节点特征变换: FFN模块(多层感知器)用于节点特征变换,增强了模型的表达能力

ViG模型

图片切成patch


(a) Grid Structure
作用:
像素级信息捕获:通过将图像切分成均匀分布的小块(Patch),每个Patch代表图像的一个局部区域。
空间关系保持:保留了图像的空间布局信息,使得模型能够理解对象的位置和相对位置。
重要性:
经典方法的基础:这是许多传统计算机视觉算法的基本假设,包括早期的人工设计特征提取方法和现代的深度学习模型(如卷积神经网络CNN)。
简单直观:易于理解和实施,是初学者入门的好选择。
(b) Sequence Structure
作用:
序列化处理:将图像的Patch按某种顺序排列,形成一维序列。
时间维度模拟:虽然实际处理的是静态图像,但通过序列化的方式,可以引入类似于自然语言处理(NLP)领域的时间维度概念。
重要性:
Transformer的应用:这种结构特别适合于基于Transformer架构的方法,如Vision Transformer(ViT)。ViT等模型通过自注意力机制对序列化的Patch进行处理,从而有效地捕捉全局上下文信息。
灵活性提升:相比固定大小的卷积核,序列化处理允许模型关注任意距离的Patch之间的关系,提高了模型的灵活性和泛化能力。
© Graph Structure
作用:
非结构化数据建模:将图像中的Patch视为图中的节点,允许模型处理更加复杂和灵活的数据结构。
适应性强:能够更好地适应各种形状和尺寸的对象,尤其是对于那些不能很好地用网格或序列描述的情况。
重要性:
图神经网络优势:结合图神经网络(GNN)的优点,能够有效处理具有复杂拓扑结构的数据,如社交网络、分子结构等。
创新性突破:在视觉任务中引入图结构是一种创新尝试,有望带来新的突破和进展,特别是在需要精细分析和理解场景的情况下。

模型架构

图像输入

首先,从一张原始图像开始。在这个例子中,图像展示了一条鱼和一个人的部分身体。

图结构生成

接下来,将图像划分为若干个Patch,并将这些Patch作为图中的节点。每个节点代表图像的一部分,而边则表示这些部分之间的关联。红色圆圈内的节点可能表示图像的关键部分,比如鱼的身体或者人的衣服图案。

网络模块

然后,进入网络模块,该模块由两部分组成:图处理和特征变换。

图处理

在这一步骤中,模型会对图结构进行处理,以提取出各个Patch之间的关系和相互影响。这可以通过图卷积操作或其他类型的图神经网络技术完成。

特征变换

经过图处理之后,得到的特征会被送入特征变换模块。这里可能会涉及到一些标准的神经网络组件,如全连接层、激活函数等,目的是进一步提炼和转化所获得的信息。

多尺度处理

整个过程会重复多次(L次),每次都会产生一个新的特征图。这样做的好处是可以从不同的层次和角度来观察和理解图像内容,增强模型的表现力。

输出头

最后,所有经过多轮处理后的特征被整合起来,传递给输出头(Head for recognition)。这个输出头负责最终的识别任务,可能是分类、回归或者其他类型的问题。

ViG代码

PatchEmbedding

class Stem(nn.Module):
    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, out_dim//8, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//8),
            act_layer(act),
            nn.Conv2d(out_dim//8, out_dim//4, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//4),
            act_layer(act),
            nn.Conv2d(out_dim//4, out_dim//2, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim//2),
            act_layer(act),
            nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            act_layer(act),
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        x = self.convs(x)
        return x

模型主体架构设计

self.backbone = Seq(*[Seq(Grapher(channels, num_knn[i], 1, conv, act, norm,
                 bias, stochastic, epsilon, 1, drop_path=dpr[i]),
                 FFN(channels, channels * 4, act=act, drop_path=dpr[i])
) for i in range(self.n_blocks)])

核心代码

聚合特征

class MRConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(MRConv2d, self).__init__()
        self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        print(x.shape, edge_index.shape)
        x_i = batched_index_select(x, edge_index[1])
        print(x_i.shape)
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
            print(x_j.shape)
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, _ = x.shape
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _)
        print(x.shape)
        return self.nn(x)

演示效果

附件使用

安装相应依赖包

pip install -r requirements.txt

获取cifa10数据集

import torchvision
import torchvision.transforms as transforms

# transforms用于数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 下载并加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 下载并加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# CIFAR-10数据集中的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

运行代码

python train.py

​​

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值