YOLOv8v10专栏限时99元订阅链接:限时99元去b站关注:AI缝合怪订阅YOLOv8v10 创新改进高效涨点+持续改进300多篇
(订阅的小伙伴,终身免费享有后续YOLOv11或是其他版本的改进专栏)
目录
一、Bi-level Routing Attention(BRA)模块介绍
四、创建涨点配置文件yolov10_BiLevelRoutingAttention.yaml
一、Bi-level Routing Attention(BRA)模块介绍
论文地址:论文地址点击此处即可跳转
代码地址:代码地址点击此处即可跳转
摘要:作为 Vision Transformer 的核心构建块,注意力是捕获长期依赖性的强大工具。然而,这种能力是有代价的:它会产生巨大的计算负担和沉重的内存占用,因为要计算所有空间位置的成对令牌交互。一系列作品试图通过将手工制作和内容无关的稀疏性引入 attention 来缓解这个问题,例如将 attention 操作限制在局部窗口、轴向条纹或扩张窗口内。与这些方法相反,我们提出了一种通过双层路由的新型动态稀疏注意力,以实现具有内容感知的更灵活的计算分配。具体来说,对于查询,首先在粗略区域级别过滤掉不相关的键值对,然后在剩余候选区域(即路由区域)的联合中应用细粒度的 token-to-token 注意。我们提供了一种简单而有效的双级路由注意实现,它利用稀疏性来节省计算和内存,同时只涉及 GPU 友好的密集矩阵乘法。根据提出的双层路由注意构建,然后提出了一个名为 BiFormer 的新通用视觉转换器。由于 BiFormer 以查询自适应的方式关注一小部分相关标记,而不会分散其他不相关标记的注意力,因此它既具有良好的性能又具有很高的计算效率,尤其是在密集的预测任务中。图像分类、对象检测和语义分割等多项计算机视觉任务的经验结果验证了我们设计的有效性。
原版注意力及其稀疏变体如下图:
(a) 原始注意力:全局操作,会产生高计算复杂度和大内存占用。
(b)-(d) 稀疏注意力:为了减少注意力的复杂度,一些方法引入了稀疏模式,如局部窗口、轴向条纹和扩张窗口。这些模式将注意力限制在特定区域,减少了考虑的键-值对数量。
(e) 可变形注意力:可变形注意力通过改变规则网格来实现图像自适应的稀疏性。这使得注意力机制可以集中关注输入图像的不同区域。
(f) 双层路由注意力:所提出的方法通过双层路由实现了动态的、查询感知的稀疏性。首先确定了前k个(本例中k=3)相关区域,然后关注它们的并集。这使得注意力机制能够根据每个查询自适应地关注最有语义相关的键-值对,从而实现高效的计算。
BRA注意力模块结构图:
BRA (Bi-Level Routing Attention) 模块是一个用于提升注意力机制计算效率的模块,特别适用于视觉Transformer。其原理主要通过引入一种动态、基于查询的稀疏注意力机制,来有效降低传统多头自注意力机制(MHSA)的计算成本和内存占用。
BRA模块的核心原理与作用
-
动态稀疏注意力机制:BRA模块通过动态调整注意力计算的范围,采用稀疏矩阵计算的方式,将注意力的计算重点放在最相关的局部区域,而不是像传统注意力机制那样在所有区域进行全局计算。这不仅提高了计算效率,还降低了不必要的内存占用。
-
双级路由机制:BRA通过一个双级的路由机制,首先通过一个较为粗略的注意力图筛选出最相关的区域,然后在这些区域内进行更加精细的注意力计算。这种层次化的处理方式极大地优化了计算过程,尤其是在处理高分辨率图像时,效果尤为显著。
-
提高效率与保持精度:尽管BRA在计算时减少了全局的关注范围,但通过其智能化的路由机制,保证了对关键区域的精确关注,确保了在提升计算效率的同时,模型性能不会显著下降,甚至在一些任务中能表现得更好。
- 提升计算效率:通过减少不必要的全局计算,BRA在处理大规模视觉数据时,能够大幅度降低计算复杂度,减少内存占用。
- 应用于视觉Transformer:BRA模块可以直接嵌入到Transformer中,替代传统的自注意力机制,使其在视觉任务中,尤其是在高分辨率图像处理时,具备更高的计算效率和更低的资源需求。
- 适用多种视觉任务:BRA模块可以应用于多种视觉任务中,例如目标检测、图像分割等,尤其在处理高分辨率图像时,能够带来显著的性能提升。
二、BRA核心代码
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor
__all__ = ['BiLevelRoutingAttention']
class TopkRouting(nn.Module):
"""
differentiable topk routing with scaling
Args:
qk_dim: int, feature dimension of query and key
topk: int, the 'topk'
qk_scale: int or None, temperature (multiply) of softmax activation
with_param: bool, wether inorporate learnable params in routing unit
diff_routing: bool, wether make routing differentiable
soft_routing: bool, wether make output value multiplied by routing weights
"""
def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
super().__init__()
self.topk = topk
self.qk_dim = qk_dim
self.scale = qk_scale or qk_dim ** -0.5
self.diff_routing = diff_routing
# TODO: norm layer before/after linear?
self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
# routing activation
self.routing_act = nn.Softmax(dim=-1)
def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor]:
"""
Args:
q, k: (n, p^2, c) tensor
Return:
r_weight, topk_index: (n, p^2, topk) tensor
"""
if not self.diff_routing:
query, key = query.detach(), key.detach()
query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c)
attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)
topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)
r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)
return r_weight, topk_index
class KVGather(nn.Module):
def __init__(self, mul_weight='none'):
super().__init__()
assert mul_weight in ['none', 'soft', 'hard']
self.mul_weight = mul_weight
def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
"""
r_idx: (n, p^2, topk) tensor
r_weight: (n, p^2, topk) tensor
kv: (n, p^2, w^2, c_kq+c_v)
Return:
(n, p^2, topk, w^2, c_kq+c_v) tensor
"""
# select kv according to routing index
n, p2, w2, c_kv = kv.size()
topk = r_idx.size(-1)
# print(r_idx.size(), r_weight.size())
# FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
# (n, p^2, p^2, w^2, c_kv) without mem cpy
dim=2,
index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)
# (n, p^2, k, w^2, c_kv)
)
if self.mul_weight == 'soft':
topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)
elif self.mul_weight == 'hard':
raise NotImplementedError('differentiable hard routing TBA')
# else: #'none'
# topk_kv = topk_kv # do nothing
return topk_kv
class QKVLinear(nn.Module):
def __init__(self, dim, qk_dim, bias=True):
super().__init__()
self.dim = dim
self.qk_dim = qk_dim
self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)
def forward(self, x):
q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
return q, kv
# q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
# return q, k, v
class BiLevelRoutingAttention(nn.Module):
"""
n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
topk: topk for window filtering
param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
param_routing: extra linear for routing
diff_routing: wether to set routing differentiable
soft_routing: wether to multiply soft routing weights
"""
def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
side_dwconv=3,
auto_pad=True):
super().__init__()
# local attention setting
self.dim = dim
self.n_win = n_win # Wh, Ww
self.num_heads = num_heads
self.qk_dim = qk_dim or dim
assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
self.scale = qk_scale or self.qk_dim ** -0.5
################side_dwconv (i.e. LCE in ShuntedTransformer)###########
self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
groups=dim) if side_dwconv > 0 else \
lambda x: torch.zeros_like(x)
################ global routing setting #################
self.topk = topk
self.param_routing = param_routing
self.diff_routing = diff_routing
self.soft_routing = soft_routing
# router
assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
self.router = TopkRouting(qk_dim=self.qk_dim,
qk_scale=self.scale,
topk=self.topk,
diff_routing=self.diff_routing,
param_routing=self.param_routing)
if self.soft_routing: # soft routing, always diffrentiable (if no detach)
mul_weight = 'soft'
elif self.diff_routing: # hard differentiable routing
mul_weight = 'hard'
else: # hard non-differentiable routing
mul_weight = 'none'
self.kv_gather = KVGather(mul_weight=mul_weight)
# qkv mapping (shared by both global routing and local attention)
self.param_attention = param_attention
if self.param_attention == 'qkvo':
self.qkv = QKVLinear(self.dim, self.qk_dim)
self.wo = nn.Linear(dim, dim)
elif self.param_attention == 'qkv':
self.qkv = QKVLinear(self.dim, self.qk_dim)
self.wo = nn.Identity()
else:
raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')
self.kv_downsample_mode = kv_downsample_mode
self.kv_per_win = kv_per_win
self.kv_downsample_ratio = kv_downsample_ratio
self.kv_downsample_kenel = kv_downsample_kernel
if self.kv_downsample_mode == 'ada_avgpool':
assert self.kv_per_win is not None
self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
elif self.kv_downsample_mode == 'ada_maxpool':
assert self.kv_per_win is not None
self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
elif self.kv_downsample_mode == 'maxpool':
assert self.kv_downsample_ratio is not None
self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
elif self.kv_downsample_mode == 'avgpool':
assert self.kv_downsample_ratio is not None
self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
elif self.kv_downsample_mode == 'identity': # no kv downsampling
self.kv_down = nn.Identity()
elif self.kv_downsample_mode == 'fracpool':
# assert self.kv_downsample_ratio is not None
# assert self.kv_downsample_kenel is not None
# TODO: fracpool
# 1. kernel size should be input size dependent
# 2. there is a random factor, need to avoid independent sampling for k and v
raise NotImplementedError('fracpool policy is not implemented yet!')
elif kv_downsample_mode == 'conv':
# TODO: need to consider the case where k != v so that need two downsample modules
raise NotImplementedError('conv policy is not implemented yet!')
else:
raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')
# softmax for local attention
self.attn_act = nn.Softmax(dim=-1)
self.auto_pad = auto_pad
def forward(self, x, ret_attn_mask=False):
"""
x: NHWC tensor
Return:
NHWC tensor
"""
x = rearrange(x, "n c h w -> n h w c")
# NOTE: use padding for semantic segmentation
###################################################
if self.auto_pad:
N, H_in, W_in, C = x.size()
pad_l = pad_t = 0
pad_r = (self.n_win - W_in % self.n_win) % self.n_win
pad_b = (self.n_win - H_in % self.n_win) % self.n_win
x = F.pad(x, (0, 0, # dim=-1
pad_l, pad_r, # dim=-2
pad_t, pad_b)) # dim=-3
_, H, W, _ = x.size() # padded size
else:
N, H, W, C = x.size()
assert H % self.n_win == 0 and W % self.n_win == 0 #
###################################################
# patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)
#################qkv projection###################
# q: (n, p^2, w, w, c_qk)
# kv: (n, p^2, w, w, c_qk+c_v)
# NOTE: separte kv if there were memory leak issue caused by gather
q, kv = self.qkv(x)
# pixel-wise qkv
# q_pix: (n, p^2, w^2, c_qk)
# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)
q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean(
[2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
##################side_dwconv(lepe)##################
# NOTE: call contiguous to avoid gradient warning when using ddp
lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win,
i=self.n_win).contiguous())
lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)
############ gather q dependent k/v #################
r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors
kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
# kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
# v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)
######### do attention as normal ####################
k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)',
m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c',
m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c',
m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)
# param-free multihead attention
attn_weight = (
q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
attn_weight = self.attn_act(attn_weight)
out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
h=H // self.n_win, w=W // self.n_win)
out = out + lepe
# output linear
out = self.wo(out)
# NOTE: use padding for semantic segmentation
# crop padded region
if self.auto_pad and (pad_r > 0 or pad_b > 0):
out = out[:, :H_in, :W_in, :].contiguous()
if ret_attn_mask:
return out, r_weight, r_idx, attn_weight
else:
return rearrange(out, "n h w c -> n c h w")
三、手把手教你添加BRA模块和修改task.py文件
1.首先在yolov10/ultralytics/nn/newsAddmodules创建一个.py文件
2.在yolov10/ultralytics/nn/newsAddmodules/__init__.py中引用
3.修改task.py文件 :在task.py中找到这个参数方法 def parse_model(d, ch, verbose=True):
elif m in (BiLevelRoutingAttention,):
args = [ch[f], *args]
四、创建涨点配置文件yolov10_BiLevelRoutingAttention.yaml
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1,1,BiLevelRoutingAttention,[]] #17
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] #20 (P4/16-medium)
- [-1,1,BiLevelRoutingAttention,[]] #21
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
- [[17, 21, 24], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
五、模型训练,检验是否可以正常运行
from ultralytics import YOLOv10
import warnings
warnings.filterwarnings('ignore')
# 模型配置文件
model_yaml_path = r"E:\yolo\yolov10\ultralytics\cfg\models\v10\yolov10_BiLevelRoutingAttention.yaml"
#数据集配置文件
data_yaml_path = r'E:\yolo\yolov10\datasets\data.yaml'
if __name__ == '__main__':
model = YOLOv10(model_yaml_path)
#训练模型
results = model.train(data=data_yaml_path,
imgsz=256,
epochs=10,
batch=4,
workers=0,
optimizer='SGD', # using SGD
amp=False, # 如果出现训练损失为Nan可以关闭amp
project='runs/V10train',
name='exp',
)
六、本文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv10改进有效涨点专栏,本专栏后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,本专栏会持续更新300+创新改进点,目前限时特价99.9,仅限前66名,之后恢复原价!!!大家尽早关注有效涨点专栏,带着大家快速高效发论文!如果大家觉得本文能帮助到你了,订阅本专栏,关注后续更多的更新~