【YOLOv8改进 - 注意力机制】Gather-Excite : 提高网络捕获长距离特征交互的能力

YOLOv8目标检测创新改进与实战案例专栏

专栏目录: YOLOv8有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLOv8基础解析+创新改进+实战案例

介绍

image-20240723152139314

摘要

虽然卷积神经网络(CNNs)中使用自下而上的局部操作符与自然图像的一些统计特性很好地匹配,但这也可能阻止这些模型捕捉上下文的长程特征交互。在这项工作中,我们提出了一种简单且轻量的方法,以更好地在CNNs中利用上下文信息。我们通过引入一对操作符来实现这一目标:聚集(gather),该操作符高效地聚合来自大空间范围的特征响应;激发(excite),将汇集的信息重新分配给局部特征。这些操作符在添加参数数量和计算复杂度方面都很便宜,并且可以直接集成到现有架构中以提高其性能。多个数据集上的实验表明,聚集-激发(gather-excite)操作符可以带来类似于增加CNN深度的好处,但成本仅为其一小部分。例如,我们发现带有聚集-激发操作符的ResNet-50在ImageNet上能够超越其101层的对应模型,而无需额外的可学习参数。我们还提出了一对参数化的聚集-激发操作符,这对进一步提高性能有帮助,并将其与最近引入的挤压-激励网络(Squeeze-and-Excitation Networks)联系起来,并分析这些变化对CNN特征激活统计的影响。

基本原理

Gather-Excite(简称GE)框架旨在增强卷积神经网络(CNNs)中对上下文的利用能力。它引入了两个主要操作符:gather和excite,这两个操作符协同工作,提高了网络捕获长距离特征交互的能力。

技术原理

1. 动机

传统的卷积神经网络主要使用局部操作符,这些操作符虽然高效,但在捕捉长距离依赖关系方面存在局限性。这是因为它们的感受野是局部的。尽管更深的层理论上具有更大的感受野,但实际上有效感受野要小得多。这一限制妨碍了CNN利用整个图像中分布的上下文信息。

2. Gather操作符 (ξG)

Gather操作符用于从较大的空间范围内聚合特征响应。它通过汇集广泛区域的信息,使网络能够收集上下文信息。这个操作符可以通过不同的池化方法实现,如平均池化,它对指定范围内的特征值进行平均。

3. Excite操作符 (ξE)

Excite操作符将聚合的信息重新分配给局部特征。这种重新分配通过根据聚合的上下文信息重新缩放原始输入特征来实现。Excite操作符使用门控机制(通常是sigmoid函数)来调整输入特征,使其受到聚合上下文的调节。

4. 在CNN中的整合

GE操作符轻量且易于整合到现有的CNN架构中。它们被插入到ResNet等网络的残差块中,就在与恒等分支求和之前。这样的整合提高了网络的表示能力,而不会显著增加计算负担。

实现细节

1. 无参数配对

在基础实现(GE-θ−)中,gather操作符使用平均池化来聚合特征,excite操作符使用sigmoid函数来调整这些聚合。这种方法不引入额外的可学习参数,并且显著提高了性能。

2. 参数化配对

为了进一步增强框架,参数化的gather操作符被引入,通过深度卷积来应用空间滤波到独立通道上。这种方法被称为GE-θ,为gather操作符添加了可学习参数,进一步提高了性能。

性能和优势

  1. ImageNet分类:实验表明,将GE操作符集成到ResNet-50中,其性能优于更深的ResNet-101,展示了上下文利用的效率。

  2. 泛化能力:GE框架在其他架构和任务中也表现出良好的泛化能力,如在MS COCO上的Faster R-CNN目标检测和CIFAR-10/100上的分类任务。

  3. 计算效率:这些操作符计算成本低,不会显著增加网络的参数数量或计算复杂度,适合资源受限的环境。

核心代码

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/140637601

  • 29
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是 PointNet++ 加注意力机制改进的代码示例: ```python import tensorflow as tf def get_attention_weight(x, y, dim): """ 获取注意力权重 :param x: 输入特征向量 :param y: 相关特征向量 :param dim: 特征向量维度 :return: 注意力权重 """ w = tf.Variable(tf.random_normal([dim, 1], stddev=0.1), name='attention_w') b = tf.Variable(tf.zeros([1]), name='attention_b') z = tf.matmul(tf.concat([x, y], axis=1), w) + b a = tf.nn.softmax(z) return a def get_attention_feature(x, y, dim): """ 获取注意力特征向量 :param x: 输入特征向量 :param y: 相关特征向量 :param dim: 特征向量维度 :return: 注意力特征向量 """ a = get_attention_weight(x, y, dim) f = tf.concat([x, y], axis=1) * a return f def pointnet_plus_plus_attention(x, k, mlp, is_training): """ PointNet++ 加注意力机制改进 :param x: 输入点云数据,shape为(batch_size, num_points, num_dims) :param k: k-NN 算法中的 k 值 :param mlp: 全连接网络结构 :param is_training: 是否为训练 :return: 输出结果,shape为(batch_size, num_points, mlp[-1]) """ num_points = x.get_shape()[1].value num_dims = x.get_shape()[-1].value with tf.variable_scope('pointnet_plus_plus_attention', reuse=tf.AUTO_REUSE): # 首先进行 k-NN 建模,找到每个点的 k 个最近邻点 # 根据每个点与其 k 个最近邻点的距离,计算点之间的权重 dists, idxs = knn(k, x) # 将点特征和最近邻点特征进行拼接 grouped_points = group(x, idxs) grouped_points = tf.concat([x, grouped_points], axis=-1) # 对拼接后的特征进行全连接网络处理 for i, num_output_channels in enumerate(mlp): grouped_points = tf_util.conv1d(grouped_points, num_output_channels, 1, 'mlp_%d' % i, is_training=is_training) # 对每个点和其最近邻点进行注意力权重计算 attention_points = [] for i in range(num_points): center_point = tf.expand_dims(tf.expand_dims(x[:, i, :], axis=1), axis=1) neighbor_points = tf.gather_nd(grouped_points, idxs[:, i, :], batch_dims=1) attention_feature = get_attention_feature(center_point, neighbor_points, num_dims * 2) attention_points.append(tf.reduce_sum(attention_feature, axis=1, keep_dims=True)) # 将注意力特征向量拼接起来,作为输出结果 output = tf.concat(attention_points, axis=1) return output ``` 在这个代码中,我们使用了 `get_attention_weight` 函数来获取注意力权重,并使用 `get_attention_feature` 函数来获取注意力特征向量。在 PointNet++ 加注意力机制改进中,我们对每个点和其 k 个最近邻点计算了注意力权重,然后用注意力权重加权求和得到了注意力特征向量,最后将所有注意力特征向量拼接起来作为输出结果。 请注意,这只是一个简单的示例,实际上,PointNet++ 加注意力机制改进的实现要比这个复杂得多。如果您需要更复杂的实现,建议参考相关论文或其他开源实现。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值