选择过程:最先尝试的是 BiFormer(BiLevelRoutingAttention),但训练时显存直接溢出;查阅资料后发现,该方法本身就存在内存占用大、计算代价高的特点,所以最终改用 LocalWindowAttention。
那么如何在yolov8中加入注意力呢,和把大象装进冰箱的步骤一样
第一步:加入注意力代码,代码位置:ultralytics/nn/modules/conv.py
最上方添加
from torch import Tensor
from typing import Tuple
from einops import rearrange#没有下载的需要先下载
import torch.nn.functional as F
import itertools
然后,在该文件最下方添加注意力代码(源码见下)。注意:Conv2d_BN 是注意力模块的依赖,也需要一并添加。
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free 2-D convolution followed by batch normalisation.

    The two layers are registered under the names ``c`` (convolution) and
    ``bn`` (batch norm), so they are reachable as ``self.c`` and ``self.bn``.

    Args:
        a: Number of input channels.
        b: Number of output channels.
        ks: Convolution kernel size.
        stride: Convolution stride.
        pad: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        bn_weight_init: Constant the batch-norm weight is initialised to.
        resolution: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def switch_to_deploy(self):
        """Fold the batch norm into the convolution.

        Returns:
            A new ``torch.nn.Conv2d`` (with bias) whose output matches this
            module's output when the batch norm uses its running statistics.
            The new layer lives on the same device and has the same dtype as
            the source convolution weight.
        """
        # Look the layers up by name rather than by registration order.
        c, bn = self.c, self.bn
        # Per-output-channel scale applied by the batch norm in eval mode.
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        # c.weight is (out, in / groups, kh, kw), hence in = size(1) * groups.
        m = torch.nn.Conv2d(
            w.size(1) * c.groups, w.size(0), w.shape[2:],
            stride=c.stride, padding=c.padding,
            dilation=c.dilation, groups=c.groups)
        # Keep the fused layer next to the weights it was built from.
        m = m.to(device=w.device, dtype=w.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class CascadedGroupAttention(torch.nn.Module):
    r"""Cascaded Group Attention.

    The input channels are split into ``num_heads`` groups. Each head attends
    over its own group, and from the second head on the previous head's output
    is added to the head's input before its q/k/v projection.

    Args:
        dim (int): Number of input channels; must be divisible by ``num_heads``.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Stored on the module; not otherwise used by this
            implementation (the value dimension is ``dim // num_heads``).
        resolution (int): Input resolution, correspond to the window size.
            The bias table is built for ``resolution * resolution`` tokens, so
            ``forward`` expects ``H * W == resolution ** 2``.
        kernels (Sequence[int]): The kernel size of the dw conv on query, one
            entry per head.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads`` or ``kernels``
            has fewer than ``num_heads`` entries.
    """

    def __init__(self, dim, key_dim, num_heads=4,
                 attn_ratio=4,
                 resolution=14,
                 kernels=(5, 5, 5, 5)):
        super().__init__()
        # Fail early: an uneven channel split only surfaces later as a shape
        # error inside forward(), and a short kernel list as an IndexError.
        if dim % num_heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        if len(kernels) < num_heads:
            raise ValueError(
                f'kernels needs at least {num_heads} entries, got {len(kernels)}')

        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = dim // num_heads  # value channels per head
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            # One 1x1 projection per head producing q, k (key_dim each) and v (d).
            qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
            # Depth-wise conv applied to the query only.
            dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, resolution=resolution))
        self.qkvs = torch.nn.ModuleList(qkvs)
        self.dws = torch.nn.ModuleList(dws)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))

        # Learnable attention bias shared by every token pair with the same
        # absolute (row, column) offset. The enumeration order below fixes the
        # meaning of each bias column, so it must not change.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode and maintain the ``ab`` bias snapshot.

        In eval mode ``self.ab`` holds the expanded ``(num_heads, N, N)`` bias
        table; entering train mode removes it.

        Returns:
            ``self``, matching ``torch.nn.Module.train`` so that
            ``module.train()`` / ``module.eval()`` can be chained.
        """
        super().train(mode)
        if mode:
            if hasattr(self, 'ab'):
                del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def forward(self, x):  # x (B,C,H,W)
        """Apply cascaded group attention to ``x`` of shape (B, C, H, W).

        Returns:
            The output of the final projection applied to the concatenated
            per-head results.
        """
        B, C, H, W = x.shape
        # Always expand the bias from the parameter. The ``ab`` snapshot is a
        # plain attribute: it is not moved by .to()/.half(), not refreshed by
        # load_state_dict(), and is absent if train()/eval() was never called.
        bias = self.attention_biases[:, self.attention_bias_idxs]
        feats_in = x.chunk(len(self.qkvs), dim=1)
        feats_out = []
        feat = feats_in[0]
        for i, (qkv, dw) in enumerate(zip(self.qkvs, self.dws)):
            if i > 0:  # add the previous output to the input
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)  # B, C/h, H, W
            q = dw(q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)  # B, C/h, N
            attn = (q.transpose(-2, -1) @ k) * self.scale + bias[i]
            attn = attn.softmax(dim=-1)  # BNN
            feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)  # BCHW
            feats_out.append(feat)
        return self.proj(torch.cat(feats_out, 1))
class LocalWindowAttention(torch.nn.Module):
    r"""Local Window Attention.

    Splits the feature map into non-overlapping square windows of side
    ``window_resolution`` (zero-padding the bottom/right edges when needed),
    runs ``CascadedGroupAttention`` inside every window, and stitches the
    windows back to the original spatial size.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Forwarded to ``CascadedGroupAttention``.
        resolution (int): Input resolution. Stored on the module; the actual
            spatial size is read from the input in ``forward``.
        window_resolution (int): Local window resolution.
        kernels (Sequence[int]): The kernel size of the dw conv on query.

    Raises:
        ValueError: If ``window_resolution`` is not positive.
    """

    def __init__(self, dim, key_dim=16, num_heads=4,
                 attn_ratio=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=(5, 5, 5, 5)):
        super().__init__()
        # Explicit raise instead of assert so the check survives ``python -O``.
        if window_resolution <= 0:
            raise ValueError('window_size must be greater than 0')
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        self.window_resolution = window_resolution
        self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
                                           attn_ratio=attn_ratio,
                                           resolution=window_resolution,
                                           kernels=kernels)

    def forward(self, x):
        """Apply windowed attention to ``x`` of shape (B, C, H, W).

        Returns:
            A tensor with the same (B, C, H, W) shape as the input.
        """
        B, C, H, W = x.shape
        ws = self.window_resolution

        # Only a map that is exactly one window can skip partitioning: the
        # attention module is built for ws * ws tokens, so a smaller map must
        # be padded up to a full window like any other size.
        if H == ws and W == ws:
            return self.attn(x)

        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        pad_b = (ws - H % ws) % ws
        pad_r = (ws - W % ws) % ws
        padding = pad_b > 0 or pad_r > 0
        if padding:
            # Pad order is last dim first: (C: 0, 0), (W: 0, pad_r), (H: 0, pad_b).
            x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))

        pH, pW = H + pad_b, W + pad_r
        nH = pH // ws
        nW = pW // ws
        # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
        x = x.reshape(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(
            B * nH * nW, ws, ws, C
        ).permute(0, 3, 1, 2)
        x = self.attn(x)
        # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
        x = x.permute(0, 2, 3, 1).reshape(
            B, nH, nW, ws, ws, C
        ).transpose(2, 3).reshape(B, pH, pW, C)
        if padding:
            # Drop the rows/columns that were added by padding.
            x = x[:, :H, :W].contiguous()
        return x.permute(0, 3, 1, 2)  # BHWC -> BCHW
第二步 :注册及引用注意力代码,
注册方法:代码位置为 ultralytics/nn/modules/__init__.py,如下图所示:
调用方法:代码位置为 ultralytics/nn/tasks.py(注意该文件在 nn 目录下,而不在 nn/modules 目录下),如下图所示:
在 parse_model 函数中已有的各个 elif m in ... 分支下方、最后的 else 分支之前,添加以下代码:
# Branch to paste into the existing if/elif chain of the model parser.
# NOTE(review): the list names BiLevelRoutingAttention, while this tutorial adds
# LocalWindowAttention — presumably the class being registered must appear in
# this list; confirm before use.
elif m in [BiLevelRoutingAttention]:
# Output channels are taken from the ch entry at index f.
c2 = ch[f]
# Prepend the channel count to the arguments given for this layer.
args = [c2, *args[0:]]
别忘了上方导入函数,如图:
这里还有一个小插曲,部署过程报错,出现问题“Unresolved reference 'ultralytics'”
原因及解决方法:项目目录多嵌套了一层,导致找不到 ultralytics 包;把文件移到外面一层目录即可。
第三步:在网络结构配置中添加注意力模块,修改 ultralytics/cfg/models/v8/yolov8-test.yaml(本文基于 yolov8 改进;如果基于其他模型,需要做出相应修改)
在网络的主干添加注意力,如下图标红位置所示:
在 yaml 中插入该模块后,位于它之后的层序号都会后移一位,因此后面所有引用到这些层的索引都要 +1,如下图所示:
ok,完活,下面我们来测试一下,运行以下训练代码
观察训练过程,如下图所示,就代表添加成功啦!
加油!!