一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发2.【车牌识别与自动收费管理系统开发
3.【手势识别系统开发4.【人脸面部活体检测系统开发
5.【图片风格快速迁移软件开发6.【人脸表表情识别系统
7.【YOLOv8多目标识别与自动标注软件开发8.【基于YOLOv8深度学习的行人跌倒检测系统
9.【基于YOLOv8深度学习的PCB板缺陷检测系统10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统
11.【基于YOLOv8深度学习的安全帽目标检测系统12.【基于YOLOv8深度学习的120种犬类检测与识别系统
13.【基于YOLOv8深度学习的路面坑洞检测系统14.【基于YOLOv8深度学习的火焰烟雾检测系统
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统16.【基于YOLOv8深度学习的舰船目标分类检测系统
17.【基于YOLOv8深度学习的西红柿成熟度检测系统18.【基于YOLOv8深度学习的血细胞检测与计数系统
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统20.【基于YOLOv8深度学习的水稻害虫检测与识别系统
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统22.【基于YOLOv8深度学习的路面标志线检测与识别系统
23.【基于YOLOv8深度学习的智能小麦害虫检测识别系统24.【基于YOLOv8深度学习的智能玉米害虫检测识别系统
25.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统26.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统
27.【基于YOLOv8深度学习的人脸面部表情识别系统28.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统
29.【基于YOLOv8深度学习的智能肺炎诊断系统30.【基于YOLOv8深度学习的葡萄簇目标检测系统
31.【基于YOLOv8深度学习的100种中草药智能识别系统32.【基于YOLOv8深度学习的102种花卉智能识别系统
33.【基于YOLOv8深度学习的100种蝴蝶智能识别系统34.【基于YOLOv8深度学习的水稻叶片病害智能诊断系统
35.【基于YOLOv8与ByteTrack的车辆行人多目标检测与追踪系统36.【基于YOLOv8深度学习的智能草莓病害检测与分割系统
37.【基于YOLOv8深度学习的复杂场景下船舶目标检测系统38.【基于YOLOv8深度学习的农作物幼苗与杂草检测系统
39.【基于YOLOv8深度学习的智能道路裂缝检测与分析系统40.【基于YOLOv8深度学习的葡萄病害智能诊断与防治系统
41.【基于YOLOv8深度学习的遥感地理空间物体检测系统42.【基于YOLOv8深度学习的无人机视角地面物体检测系统
43.【基于YOLOv8深度学习的木薯病害智能诊断与防治系统44.【基于YOLOv8深度学习的野外火焰烟雾检测系统
45.【基于YOLOv8深度学习的脑肿瘤智能检测系统46.【基于YOLOv8深度学习的玉米叶片病害智能诊断与防治系统
47.【基于YOLOv8深度学习的橙子病害智能诊断与防治系统48.【基于深度学习的车辆检测追踪与流量计数系统
49.【基于深度学习的行人检测追踪与双向流量计数系统50.【基于深度学习的反光衣检测与预警系统
51.【基于深度学习的危险区域人员闯入检测与报警系统52.【基于深度学习的高密度人脸智能检测与统计系统
53.【基于深度学习的CT扫描图像肾结石智能检测系统54.【基于深度学习的水果智能检测系统
55.【基于深度学习的水果质量好坏智能检测系统56.【基于深度学习的蔬菜目标检测与识别系统
57.【基于深度学习的非机动车驾驶员头盔检测系统58.【太基于深度学习的阳能电池板检测与分析系统
59.【基于深度学习的工业螺栓螺母检测60.【基于深度学习的金属焊缝缺陷检测系统
61.【基于深度学习的链条缺陷检测与识别系统62.【基于深度学习的交通信号灯检测识别
63.【基于深度学习的草莓成熟度检测与识别系统64.【基于深度学习的水下海生物检测识别系统
65.【基于深度学习的道路交通事故检测识别系统66.【基于深度学习的安检X光危险品检测与识别系统
67.【基于深度学习的农作物类别检测与识别系统68.【基于深度学习的危险驾驶行为检测识别系统
69.【基于深度学习的维修工具检测识别系统

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

引言

近年来,自注意力(self-sttention)交叉注意力(cross-attention)已经成为计算机视觉领域的强大机制,在vison transformers(ViTs)和各种生成模型等模型的成功中发挥了关键作用。这些注意力机制使模型能够专注于输入图像的不同部分,改善机器感知和处理视觉信息的方式。在本文中,我们将探讨什么是自注意力和交叉注意力,它们是如何工作的,以及它们在计算机视觉中的具体应用。

深度学习中的注意力是什么?

在高层次上,注意力机制允许模型以不同的方式衡量不同输入数据的重要性。注意力不是平等对待所有输入,而是帮助模型专注于与手头任务更相关的特定区域或特征。这首先是在自然语言处理(NLP)中引入的,但后来被应用于计算机视觉任务,彻底改变了深度学习模型理解图像的方式。

自注意力:关注内部关系

自注意力是一种机制,模型计算单个输入(通常是计算机视觉中的图像)的所有部分之间的关系。术语“自我”指的是模型关注同一输入中的关系,允许它捕获局部和全局依赖关系。

自注意力是如何工作的?

在自注意力中,对于每个像素(或图像中的补丁),模型计算该像素与其他像素之间的“注意力得分”。这些分数帮助模型决定在对任何特定部分进行预测时对图像的不同部分给予多少关注。

这里有一个简化的过程:

img

  1. Query、Key和Value:图像的每个部分都被投影到三个不同的向量中:Query(Q)、Key(K)和Value(V)。Query和Key用于计算注意力分数,而Value保存模型将向前传递的信息。
  2. 注意力分数计算:任何两个像素之间的注意力分数是通过取它们的Query和Key向量的点积来计算的,然后是一个softmax操作来规范化这些分数。
  3. 加权和:然后使用注意力分数来加权值向量。该模型输出每个像素的值向量的加权和,这使得它能够专注于图像中最相关的部分。

img

img

计算机视觉应用

自注意力是视觉转换器(ViTs)的核心,它将图像视为一系列补丁,而不是使用卷积。这种方法允许ViT捕获像素之间的长距离和短距离依赖关系,使其对于图像分类,对象检测和分割等任务非常有效。

交叉注意力:连接不同的模态

交叉注意力与自注意力的不同之处在于,它在两个不同的输入之间运行,而不是在一个输入内运行。在交叉注意力中,注意力机制允许模型基于来自另一个输入(例如文本提示或不同的图像)的信息来关注一个输入(例如图像)的相关部分。

交叉注意力是如何工作的?

与自我注意力类似,交叉注意力也使用Query、Key和Value向量,但这些向量来自两个不同的来源:

img

  1. 查询来自一个输入(例如,文本)。
  2. 键和值来自第二输入(例如,图像)。

然后,交叉注意力机制计算来自一个模态的查询和来自另一个模态的键之间的注意力分数。这允许模型关注一个输入中与另一个输入最相关的部分,从而弥合不同数据类型之间的差距。

img

img

计算机视觉应用

交叉注意是多模态模型的基本组成部分,其中模型必须处理多种类型的输入。例如,在文本到图像生成(例如,DALL·E和稳定扩散),交叉注意力使模型能够将文本描述与相关视觉特征对齐。通过允许文本引导模型关注图像的特定区域,交叉注意确保生成的图像与文本提示准确匹配。

另一个令人兴奋的应用是图像引导的图像合成视频生成,其中图像或一组图像可以引导新图像或帧的合成。在这里,交叉注意力帮助模型混合来自不同来源的信息,确保输出的一致性。

自注意力与交叉注意力主要区别

虽然自注意力和交叉注意力在现代计算机视觉模型中起着关键作用,但它们的区别在于它们所捕捉的关系的性质:

img

为什么注意力在计算机视觉中很重要

注意力机制,无论是自注意力还是交叉注意力,都提供了灵活性和可扩展性,使模型能够更有效地处理复杂的视觉任务。随着计算机视觉的不断发展,基于注意力的模型有望在广泛的应用中扩大其影响力,从自动驾驶汽车和医学成像到生成令人惊叹的视觉内容的创意工具。

自我注意力和交叉注意力都为模型提供了更智能地“看”世界的能力。自我注意使他们能够专注于图像中的重要细节,而交叉注意使他们能够联合收割机从多种模态中获得见解,从而对视觉信息产生更深入、更细致的理解。

结论

自注意力和交叉注意力是重新定义计算机视觉领域的变革机制。它们使机器能够更有效地理解视觉数据,无论是通过关注图像内的关系,还是将图像与文本等其他形式联系起来。随着注意力机制的不断成熟,它们在推动基于视觉的应用程序创新方面的潜力是无限的,使其成为人工智能未来的重要工具。


在这里插入图片描述

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

### 交叉注意力机制的实现 交叉注意力机制是一种扩展形式的自注意力机制,其核心思想是在两个不同的上下文中计算注意力权重。具体来说,查询(Query)来自一个序列,而键(Key)和值(Value)则来源于另一个序列。以下是基于 PyTorch 和 TensorFlow 的交叉注意力机制实现代码。 #### 使用 PyTorch 实现交叉注意力机制 ```python import torch import torch.nn as nn import math class CrossAttention(nn.Module): def __init__(self, dim_model, num_heads, dropout=0.1): super(CrossAttention, self).__init__() assert dim_model % num_heads == 0, "dim_model must be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim_model // num_heads self.query_projection = nn.Linear(dim_model, dim_model) self.key_projection = nn.Linear(dim_model, dim_model) self.value_projection = nn.Linear(dim_model, dim_model) self.dropout = nn.Dropout(dropout) self.out_projection = nn.Linear(dim_model, dim_model) def forward(self, query, key_value_input, mask=None): batch_size, seq_len_q, _ = query.size() _, seq_len_kv, _ = key_value_input.size() # Linear projections queries = self.query_projection(query).view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2) keys = self.key_projection(key_value_input).view(batch_size, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2) values = self.value_projection(key_value_input).view(batch_size, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2) # Scaled dot-product attention scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch_size, heads, seq_len_q, seq_len_kv) if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf')) attn_weights = torch.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights) context = torch.matmul(attn_weights, values) # (batch_size, heads, seq_len_q, head_dim) context = context.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1) # (batch_size, seq_len_q, dim_model) output = self.out_projection(context) return output, attn_weights ``` 上述代码实现了交叉注意力模块,其中 `query` 是源序列,`key_value_input` 是目标序列[^1]。 --- #### 使用 TensorFlow 实现交叉注意力机制 ```python import tensorflow as tf def scaled_dot_product_attention(q, k, v, mask=None): matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) dk = tf.cast(tf.shape(k)[-1], tf.float32) scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) if mask is not None: scaled_attention_logits += (mask * -1e9) attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) return output, attention_weights class MultiHeadCrossAttention(tf.keras.layers.Layer): def __init__(self, d_model, num_heads): super(MultiHeadCrossAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model assert d_model % self.num_heads == 0 self.depth = d_model // self.num_heads self.wq = tf.keras.layers.Dense(d_model) self.wk = tf.keras.layers.Dense(d_model) self.wv = tf.keras.layers.Dense(d_model) self.dense = tf.keras.layers.Dense(d_model) def split_heads(self, x, batch_size): """Split the last dimension into (num_heads, depth). Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) """ x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) return tf.transpose(x, perm=[0, 2, 1, 3]) def call(self, q, kv, mask=None): batch_size = tf.shape(q)[0] q = self.wq(q) # (batch_size, seq_len_q, d_model) k = self.wk(kv) # (batch_size, seq_len_kv, d_model) v = self.wv(kv) # (batch_size, seq_len_kv, d_model) q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_kv, depth) v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_kv, depth) scaled_attention, attention_weights = scaled_dot_product_attention( q, k, v, mask) scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) return output, attention_weights ``` 此代码展示了如何在 TensorFlow 中构建一个多头交叉注意力层,允许不同序列间的交互[^3]。 --- ### 参数调整建议 为了优化交叉注意力机制的表现,可以尝试以下方法: - 调整模型维度 (`d_model`) 和头部数量 (`num_heads`) 来适应特定任务的需求[^4]。 - 应用正则化技术(如 Dropout 或 L2 正则化)防止过拟合。 - 对于大规模数据集,可采用稀疏注意力或其他高效变体减少计算开销。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

阿_旭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值