I. Introduction
Activation functions play a key role in object detection. They serve several main purposes:
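- Introducing non-linearity: The basic building blocks of a neural network (convolutional layers, fully connected layers, and so on) are linear transformations at their core, and any stack of such layers without activations still computes a single linear map. Activation functions add non-linearity so the network can learn and represent more complex features. This matters especially in object detection, where objects can appear at any position, scale, and pose, and may overlap or occlude one another; non-linear capacity helps the model understand and separate such complex scenes.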
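- Controlling gradient flow: The shape of an activation function determines how gradients propagate during backpropagation, which directly affects weight updates and learning. For example, ReLU outputs zero with zero gradient for negative inputs, which can leave some neurons permanently inactive (the "dying ReLU" problem). Variants such as Leaky ReLU and PReLU keep a non-zero slope in the negative region so gradients keep flowing through deep networks, while FReLU takes a different route and replaces ReLU's fixed zero branch with a learned spatial condition.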
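- Increasing representational power: Some activation functions improve performance on particular tasks. For example, FReLU (Funnel Activation) incorporates spatial context at the activation stage and improves the model's spatial modeling, which can benefit tasks such as object detection that require precise localization.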
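- Computational cost: Activation functions differ in computational complexity. In real-time detectors such as the YOLO series, choosing an activation that is both cheap and effective (such as ReLU) is necessary to keep inference fast and resource-efficient.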
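- Adjusting the output range: Some activation functions, such as Sigmoid and Softmax, squash outputs into a fixed range. This is especially useful in output layers, for example to turn network outputs into probabilities for class prediction (the YOLOv5/YOLOv7 detection heads apply Sigmoid to their raw outputs).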
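In short, the activation function determines not only what a network can learn but also, to a large extent, the accuracy, training efficiency, and final detection performance of a detection model. Choosing and designing it carefully is therefore an important part of building a detection network.
II. Common and Effective Activation Functions in YOLO Training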
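To make the non-linearity and gradient-flow points concrete, here is a small pure-Python sketch with toy 2×2 weights (chosen arbitrarily, no framework). It shows that two stacked linear layers equal one merged linear layer unless an activation sits between them, and that ReLU passes zero gradient for a negative input while Leaky ReLU (a slope of 0.1 is assumed here) still passes a small one.

```python
# Toy demonstration: linear layers collapse without an activation,
# and ReLU vs Leaky ReLU differ in the gradient they pass for x < 0.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def apply(w, x):
    """Apply matrix w to vector x (both plain Python lists)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def relu(v):
    return [max(0.0, t) for t in v]

def relu_grad(t):
    # derivative of max(0, t); taken as 0 at t <= 0
    return 1.0 if t > 0 else 0.0

def leaky_relu_grad(t, slope=0.1):
    # derivative of Leaky ReLU: 1 for t > 0, `slope` otherwise
    return 1.0 if t > 0 else slope

W1 = [[1.0, -2.0], [0.5, 3.0]]   # toy "layer 1" weights
W2 = [[2.0, 1.0], [-1.0, 0.5]]   # toy "layer 2" weights
x = [1.0, 2.0]

stacked = apply(W2, apply(W1, x))       # two linear layers in sequence
merged = apply(matmul(W2, W1), x)       # one layer with weights W2 @ W1
with_relu = apply(W2, relu(apply(W1, x)))  # ReLU in between breaks the collapse

print("stacked:", stacked, "merged:", merged, "with ReLU:", with_relu)
print("grad at -2.0  ReLU:", relu_grad(-2.0), " Leaky ReLU:", leaky_relu_grad(-2.0))
```

The merged and stacked results coincide, so extra depth adds nothing without a non-linearity; the ReLU version produces a different output that no single matrix reproduces for all inputs.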
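- SiLU (Sigmoid Linear Unit): SiLU(x) = x · sigmoid(x), i.e. the input gated by its own sigmoid. It is also known as Swish; strictly, it is Swish with β = 1 (the general form x · sigmoid(βx) can make β learnable). SiLU is smooth and non-monotonic, adds non-linearity, and does not saturate for large positive inputs, which helps mitigate vanishing gradients.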
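- ReLU (Rectified Linear Unit): One of the most widely used activation functions, ReLU(x) = max(0, x): it outputs the input when the input is positive and 0 otherwise. Because its gradient is 1 for positive inputs, it avoids the gradient saturation of sigmoid and tanh in that region and speeds up training. However, its gradient is 0 for negative inputs, which can lead to the "dying ReLU" problem.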
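- Leaky ReLU: Proposed to fix ReLU's zero gradient for negative inputs. It keeps a small fixed slope α for negative inputs (x for x ≥ 0, αx otherwise) so gradients still flow. PyTorch's nn.LeakyReLU defaults to α = 0.01, while YOLOv7-tiny uses 0.1.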
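- FReLU (Funnel ReLU, also called Funnel Activation): Extends ReLU into a 2D visual activation, FReLU(x) = max(x, T(x)), where the "funnel condition" T(x) is a depthwise convolution over each pixel's neighborhood (3×3 in the paper) followed by BatchNorm. It replaces ReLU's fixed zero with a learned, spatially aware threshold, at the cost of extra parameters and computation.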
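- PReLU (Parametric ReLU): Like Leaky ReLU, but the negative slope is a learnable parameter (a single shared value or one per channel; PyTorch's nn.PReLU defaults to one shared slope initialized to 0.25). This suits cases where different layers may need different negative slopes.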
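- Hardswish: Introduced in MobileNetV3 as a cheaper approximation of Swish: Hardswish(x) = x · ReLU6(x + 3) / 6, which replaces the sigmoid with a piecewise-linear function. It runs efficiently on mobile devices while keeping good accuracy.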
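- Mish: Proposed by Diganta Misra, Mish(x) = x · tanh(softplus(x)). Like Swish it is smooth, non-monotonic, and self-gated, while remaining unbounded above like ReLU; its author reports accuracy gains across a range of tasks.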
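- ELU (Exponential Linear Unit): ELU(x) = x for x > 0 and α(e^x - 1) otherwise. Its negative outputs push mean activations closer to zero, which reduces the bias shift caused by ReLU and speeds up learning.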
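- CELU (Continuously Differentiable Exponential Linear Unit): A reparameterization of ELU, CELU(x) = max(0, x) + min(0, α(e^(x/α) - 1)). ELU's derivative is continuous at 0 only when α = 1; CELU is continuously differentiable for every α, which suits applications that require strict smoothness.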
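- GELU (Gaussian Error Linear Unit): GELU(x) = x · Φ(x), where Φ is the cumulative distribution function of the standard normal distribution, so each input is weighted by the probability that a standard normal variable falls below it. GELU is the default activation in many Transformer models such as BERT and ViT.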
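- SELU (Scaled Exponential Linear Unit): Designed for self-normalizing neural networks. SELU is a scaled ELU with fixed constants (λ ≈ 1.0507, α ≈ 1.6733) chosen so that, under suitable conditions (e.g. LeCun-normal weight initialization in fully connected networks), activations converge toward zero mean and unit variance, reducing the need for explicit normalization during training.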
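Note that YOLOv5 and YOLOv7 use SiLU as their default activation, YOLOv7-tiny uses LeakyReLU, and the ResNet family uses ReLU. Before choosing an activation as your baseline, check the relevant paper and configuration first; only by comparing against that baseline can you draw conclusions about an improvement.
In YOLOv5/v7, additional activation implementations (e.g. Mish, FReLU) generally live in utils/activations.py.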
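Custom activations in such a file are ordinary nn.Module subclasses. As an illustration, here is a minimal FReLU sketch that follows the paper's definition max(x, T(x)), with T a depthwise 3×3 convolution plus BatchNorm; it is written for this article and is not copied from the YOLO repositories.

```python
import torch
import torch.nn as nn

class FReLU(nn.Module):
    """Funnel activation sketch: y = max(x, T(x)), where T is a per-channel
    (depthwise) k x k convolution followed by BatchNorm."""

    def __init__(self, channels: int, k: int = 3):
        super().__init__()
        # groups=channels makes the convolution depthwise; padding keeps H and W
        self.conv = nn.Conv2d(channels, channels, k, stride=1, padding=k // 2,
                              groups=channels, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # element-wise max between the input and its spatial condition
        return torch.max(x, self.bn(self.conv(x)))

act = FReLU(16)
x = torch.randn(2, 16, 32, 32)
y = act(x)
print(y.shape)
```

Because FReLU needs the channel count, it cannot simply replace nn.SiLU() in the self.act line of Conv; it would have to be built with the output channels, e.g. a hypothetical self.act = FReLU(c2) if act is True else ...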
III. Modifying YOLOv7-tiny
With Section II in mind, open common.py in the models folder of the YOLOv7 project and search (Ctrl+F) for the following line in the Conv class:
```python
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
Replace that single line with the variant for the activation you want:
```python
# Keep exactly one of these lines uncommented.
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# negative_slope defaults to 0.01; YOLOv7-tiny itself uses nn.LeakyReLU(0.1)
# self.act = nn.LeakyReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# nn.Mish requires PyTorch >= 1.9
# self.act = nn.Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.GELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.SELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.RReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.PReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
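The line being edited chooses the activation in three ways: act=True uses the default, an nn.Module instance is used as-is, and any other value (e.g. False) gives nn.Identity. The simplified re-creation of Conv below (autopad reduced to "same" padding for odd kernels; not the project's actual file) exercises all three cases. The second case is why an activation written explicitly in a YAML row takes precedence over the default in common.py.

```python
import torch
import torch.nn as nn

def autopad(k, p=None):
    # simplified: "same" padding for odd kernel sizes unless p is given
    return k // 2 if p is None else p

class Conv(nn.Module):
    # simplified Conv -> BatchNorm -> activation block
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # True -> default SiLU; nn.Module -> use it; anything else -> Identity
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

default_conv = Conv(3, 8, 3)                             # act=True
yaml_conv = Conv(3, 8, 3, 1, None, 1, nn.LeakyReLU(0.1))  # explicit module, as in a YAML row
no_act_conv = Conv(3, 8, 3, act=False)                   # no activation

for m in (default_conv, yaml_conv, no_act_conv):
    print(type(m.act).__name__)

out = yaml_conv(torch.randn(1, 3, 16, 16))
print(out.shape)
```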
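Next, open yolov7-tiny.yaml and change each Conv row to the activation you need; the example below uses SiLU. The activation is the sixth element of each Conv argument list (c2, k, s, p, g, act). Because it is passed explicitly as an nn.Module, it overrides the default in common.py, which only applies to Conv layers built with act=True; for YOLOv7-tiny, the YAML edit is therefore what actually changes the activation.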
```yaml
# parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32

# yolov7-tiny backbone
backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  [[-1, 1, Conv, [32, 3, 2, None, 1, nn.SiLU()]], # 0-P1/2
   [-1, 1, Conv, [64, 3, 2, None, 1, nn.SiLU()]], # 1-P2/4
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]], # 7
   [-1, 1, MP, []], # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]], # 14
   [-1, 1, MP, []], # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.SiLU()]], # 21
   [-1, 1, MP, []], # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.SiLU()]], # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, SPPF, [256]], # 29
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]], # route backbone P4
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]], # 39
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]], # route backbone P3
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]], # 49
   [-1, 1, Conv, [128, 3, 2, None, 1, nn.SiLU()]],
   [[-1, 39], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]], # 57
   [-1, 1, Conv, [256, 3, 2, None, 1, nn.SiLU()]],
   [[-1, 29], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.SiLU()]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.SiLU()]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.SiLU()]], # 65
   [49, 1, Conv, [128, 3, 1, None, 1, nn.SiLU()]],
   [57, 1, Conv, [256, 3, 1, None, 1, nn.SiLU()]],
   [65, 1, Conv, [512, 3, 1, None, 1, nn.SiLU()]],
   [[66, 67, 68], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
```
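When the model is built, each YAML row becomes Python objects. A YAML loader returns None and nn.SiLU() in these rows as plain strings, and the model parser (parse_model in models/yolo.py) evaluates string arguments with eval. The sketch below reproduces that step in simplified form for two rows; the rows are hard-coded as they would look after loading, and the restricted namespace {"nn": nn} is an assumption of this sketch rather than the parser's exact behavior.

```python
import torch.nn as nn

# Rows as a YAML loader would return them (assumed form): non-numeric
# plain scalars such as None and nn.SiLU() arrive as strings.
rows = [
    [-1, 1, "Conv", [32, 3, 2, "None", 1, "nn.SiLU()"]],
    [-1, 1, "Conv", [32, 3, 2, "None", 1, "nn.LeakyReLU(0.1)"]],
]

namespace = {"nn": nn}  # only torch.nn is visible to eval in this sketch
resolved = []
for f, n, m, args in rows:
    # evaluate string arguments into Python objects; leave numbers untouched
    resolved.append([eval(a, namespace) if isinstance(a, str) else a for a in args])

for args in resolved:
    print(args)
```

The sixth resolved argument is now an nn.Module, so Conv takes the "use it as-is" branch and ignores its built-in default.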
For LeakyReLU, change the same yolov7-tiny.yaml to:
```yaml
# parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32

# yolov7-tiny backbone
backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 0-P1/2
   [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 1-P2/4
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 7
   [-1, 1, MP, []], # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 14
   [-1, 1, MP, []], # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 21
   [-1, 1, MP, []], # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, SPPF, [256]], # 29
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 39
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 49
   [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 39], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 57
   [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 29], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 65
   [49, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [57, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [65, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[66, 67, 68], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
```
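After editing common.py or the YAML, it is worth checking which activations the built model actually contains. The hypothetical helper count_activations below tallies activation modules by class name; the toy nn.Sequential stands in for a real model.

```python
from collections import Counter

import torch.nn as nn

# Activation classes to look for (nn.Mish requires PyTorch >= 1.9)
ACT_TYPES = (nn.ReLU, nn.LeakyReLU, nn.SiLU, nn.Hardswish, nn.Mish, nn.ELU,
             nn.CELU, nn.GELU, nn.SELU, nn.RReLU, nn.PReLU)

def count_activations(model: nn.Module) -> Counter:
    """Tally activation modules by class name across all submodules."""
    return Counter(type(m).__name__ for m in model.modules() if isinstance(m, ACT_TYPES))

# Toy stand-in for a parsed detector (hypothetical layers)
toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.LeakyReLU(0.1),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.LeakyReLU(0.1),
    nn.MaxPool2d(2), nn.Conv2d(32, 32, 1), nn.SiLU(),
)
counts = count_activations(toy)
print(counts)
```

For a real check, pass the model built from the edited yolov7-tiny.yaml (for example through the Model class in models/yolo.py). If a module such as SPPF builds its internal Conv layers with the default act=True, as YOLOv5's SPPF does, those layers keep the common.py default, so a mix of LeakyReLU and SiLU can appear after editing only the YAML.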
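IV. Modifying YOLOv7
With Section II in mind, open common.py in the models folder of the YOLOv7 project and search (Ctrl+F) for the following line. Unlike YOLOv7-tiny, the standard yolov7.yaml rows (e.g. [-1, 1, Conv, [32, 3, 1]]) do not pass an activation, so this default applies throughout the model.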
```python
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
Replace it with the activation you want:
```python
# Keep exactly one of these lines uncommented.
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# negative_slope defaults to 0.01; pass a value explicitly, e.g. nn.LeakyReLU(0.1), if needed
# self.act = nn.LeakyReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# nn.Mish requires PyTorch >= 1.9
# self.act = nn.Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.GELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.SELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.RReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.PReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
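For quick experiments that avoid editing common.py, activations can also be swapped in an already-built model. The hypothetical helper replace_activations below walks the module tree and replaces every instance of one activation class. It creates a fresh module per site so that parameterized activations such as nn.PReLU each get their own parameters.

```python
import torch.nn as nn

def replace_activations(module: nn.Module, old_type, make_new) -> int:
    """Recursively replace every child of type `old_type` with make_new().
    Returns the number of replacements made (hypothetical helper)."""
    count = 0
    for name, child in list(module.named_children()):
        if isinstance(child, old_type):
            setattr(module, name, make_new())  # re-registers the submodule under the same name
            count += 1
        else:
            count += replace_activations(child, old_type, make_new)
    return count

# Toy model standing in for a built detector
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.SiLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.SiLU()),
)
n = replace_activations(toy, nn.SiLU, lambda: nn.Hardswish())
print(n)
print(toy)
```

Weights trained with SiLU do not carry over exactly after such a swap, so the model should be retrained or fine-tuned before comparing it with the baseline.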
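V. Modifying YOLOv5
With Section II in mind, open common.py in the models folder of the YOLOv5 project and search (Ctrl+F) for the following line. In more recent YOLOv5 releases (e.g. v7.0), Conv instead defines a class attribute default_act = nn.SiLU() and sets self.act = self.default_act if act is True else ...; in that case, change default_act.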
```python
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
Replace it with the activation you want:
```python
# Keep exactly one of these lines uncommented.
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# negative_slope defaults to 0.01; pass a value explicitly, e.g. nn.LeakyReLU(0.1), if needed
# self.act = nn.LeakyReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# nn.Mish requires PyTorch >= 1.9
# self.act = nn.Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.ELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.GELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.SELU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.RReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.PReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
```
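When Conv takes its default from a class attribute such as default_act (the pattern used in recent YOLOv5 releases), the default can be changed once, before the model is built, instead of editing the line. The simplified Conv below is a hypothetical sketch of that pattern, not the project's file.

```python
import torch.nn as nn

class Conv(nn.Module):
    # simplified sketch: the default activation lives on the class
    default_act = nn.SiLU()

    def __init__(self, c1, c2, k=1, s=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k // 2, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

# Override the default before any layers are constructed
Conv.default_act = nn.LeakyReLU(0.1)
a, b = Conv(3, 8, 3), Conv(8, 8, 3)
print(type(a.act).__name__, a.act is b.act)
```

Every layer built this way shares one activation instance. That is harmless for parameter-free activations, but with nn.PReLU all layers would share a single learnable slope, so a per-layer copy would be needed there. Recent YOLOv5 versions also appear to read an optional top-level activation key from the model YAML (e.g. activation: nn.LeakyReLU(0.1)) and assign it to Conv.default_act in parse_model; check your version's models/yolo.py before relying on it.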
More articles are on the way, focused on being concise and accurate. Feel free to follow along and join the discussion!