1. 本文介绍
CPA-Enhancer通过链式思考（Chain-of-Thought）提示机制，实现了对未知退化条件下图像的自适应增强，显著提升了物体检测性能。其插件式设计便于集成到现有检测框架中，并在物体检测及其他视觉任务中取得了新的最优（state-of-the-art）性能，展现了广泛的应用潜力。
关于CPA-Enhancer的详细介绍，请参阅论文：https://arxiv.org/abs/2403.11220v3
本文将讲解如何将CPA-Enhancer融合进YOLOv8。
话不多说，上代码！
2. 将CPA-Enhancer融合进YOLOv8
2.1 步骤一
首先找到目录'ultralytics/nn'，在该目录下创建一个'Addmodules'文件夹，然后在'Addmodules'文件夹中新建一个Enhancer.py文件（文件名可按个人习惯命名），并将下面CPA-Enhancer的核心代码复制进去。注意：这段代码依赖einops库，如尚未安装，请先执行`pip install einops`。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from einops import rearrange
from einops.layers.torch import Rearrange
# Public names exported by `from <this module> import *`.
# NOTE(review): CPA_arch is not defined in the visible part of this file;
# presumably it is defined further down — confirm before relying on it.
__all__ = [
    "CPA_arch",
]
class RFAConv(nn.Module):
    """Receptive-field attention convolution built from grouped convolutions.

    Each input channel's k x k receptive field is expanded into k*k channels
    and re-weighted by a softmax over those k*k positions. The result is laid
    out as a spatial map k times larger in each dimension, then reduced back by
    a k-strided k x k convolution.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: receptive-field size k (default 3).
        stride: stride used when sampling receptive fields (default 1).
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        pad = kernel_size // 2
        expanded = in_channel * (kernel_size ** 2)  # k*k slots per input channel

        # Attention logits: local average pooling, then a per-channel 1x1
        # grouped conv that emits k*k scores for every input channel.
        pool = nn.AvgPool2d(kernel_size=kernel_size, padding=pad, stride=stride)
        score_conv = nn.Conv2d(in_channel, expanded, kernel_size=1,
                               groups=in_channel, bias=False)
        self.get_weight = nn.Sequential(pool, score_conv)

        # Receptive-field features: a grouped k x k conv (one group per input
        # channel) producing k*k feature maps per channel.
        self.generate_feature = nn.Sequential(
            nn.Conv2d(in_channel, expanded, kernel_size=kernel_size, padding=pad,
                      stride=stride, groups=in_channel, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU(),
        )

        # Non-overlapping k x k conv (stride == kernel) over the enlarged map
        # collapses each k x k tile back into one output pixel.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=kernel_size),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        k = self.kernel_size
        batch, channels = x.shape[:2]

        logits = self.get_weight(x)
        height, width = logits.shape[2:]

        # (b, c*k*k, h, w) -> (b, c, k*k, h, w); normalise across the k*k slots.
        attn = logits.view(batch, channels, k * k, height, width).softmax(dim=2)
        feats = self.generate_feature(x).view(batch, channels, k * k, height, width)
        weighted = feats * attn

        # Equivalent to einops 'b c (n1 n2) h w -> b c (h n1) (w n2)':
        # split the k*k axis into (n1, n2), interleave n1 with h and n2 with w.
        tiled = (
            weighted.reshape(batch, channels, k, k, height, width)
            .permute(0, 1, 4, 2, 5, 3)
            .reshape(batch, channels, height * k, width * k)
        )
        return self.conv(tiled)
class Downsample(nn.Module):
    """Halve the spatial size and fold space into channels.

    A bias-free 3x3 conv maps n_feat -> n_feat // 2 channels. PixelUnshuffle(2)
    then moves each 2x2 spatial block into channels, giving
    (b, n_feat, h, w) -> (b, (n_feat // 2) * 4, h / 2, w / 2).
    PixelUnshuffle requires h and w to be divisible by 2.
    """

    def __init__(self, n_feat):
        super().__init__()
        squeeze = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1,
                            padding=1, bias=False)
        unshuffle = nn.PixelUnshuffle(2)
        self.body = nn.Sequential(squeeze, unshuffle)

    def forward(self, x):
        out = self.body(x)
        return out
class Upsample(nn.Module):
    """Double the spatial size by unfolding channels into space.

    A bias-free 3x3 conv maps n_feat -> 2 * n_feat channels. PixelShuffle(2)
    then spreads every 4 channels over a 2x2 spatial block, giving
    (b, c, h, w) -> (b, c / 2, 2h, 2w). n_feat must be even so that 2 * n_feat
    is divisible by 4.
    """

    def __init__(self, n_feat):
        super().__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1,
                           padding=1, bias=False)
        shuffle = nn.PixelShuffle(2)
        self.body = nn.Sequential(expand, shuffle)

    def forward(self, x):
        out = self.body(x)
        return out
class SpatialAttention(nn.Module):
def __init__(self):
super(SpatialAttention, self).__init__()
self.sa = nn.Conv2d(2, 1, 7, padding=3, padding_mode='reflect', bias=True)
def forward(self, x): # x:[b,c,h,w]
x_avg = torch.mean(x, dim=1, keepdim=True) # (b,1,h,w)
x_max, _ = torch.max(x, dim=1, keepdim=True) # (b,1,h,w)
x2 = torch.concat([x_avg, x_max], dim=1) # (b,2,h,w)
sattn &