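Preface
AB (AlexeyAB) has put out YoloV7, and I felt it was worth following up on. Its concat structure is something I had not seen before, and it looks rather interesting.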
Source code download
https://github.com/bubbliiiing/yolov7-pytorch
If you like it, feel free to give it a star.
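YoloV7 improvements (not a complete list)
1. Backbone: features are extracted with a novel multi-branch stacking structure, so the skip connections are much denser than in earlier Yolo versions. A novel downsampling structure is also used, in which max pooling and a stride-2 convolution extract and compress features in parallel.
2. Enhanced feature extraction: like the backbone, this part uses the multi-input stacking structure for feature extraction and downsamples with max pooling and a stride-2 convolution in parallel.
3. A special SPP structure: an SPP with a CSP structure enlarges the receptive field. Introducing CSP into SPP gives the module a large residual branch that helps optimization and feature extraction.
4. Adaptive multiple positive samples: in the Yolo series before YoloV5, each ground-truth box corresponded to a single positive sample during training, i.e. only one prior box was responsible for predicting each ground-truth box. To train more efficiently, YoloV7 increases the number of positive samples, so each ground-truth box can be predicted by several prior boxes. In addition, for each ground-truth box it computes the IoU and the class score of the predicted boxes (the prior boxes after adjustment), combines them into a cost, and uses that cost to find the prior boxes best suited to the ground-truth box.
5. Borrowing from RepVGG, RepConv is introduced in specific parts of the network. After fusion, it reduces the number of parameters while preserving the network's accuracy.
6. An auxiliary head is used to help convergence, but it is not used in the smaller YoloV7 and YoloV7-X models.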
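Item 4 can be made concrete with a small sketch. This is not YoloV7's loss code. It is a simplified, pure-Python assignment in the spirit of YOLOX's SimOTA, which I am assuming is close to the cost-based matching described above. Each ground-truth box gets a dynamic number k of positives (the integer part of the sum of its top IoUs, at least 1), chosen as the predictions with the lowest cost = classification cost + 3 * IoU cost. The function names, the weight 3.0 and topk=10 are my own assumptions, and the YoloV5-style pre-selection of candidate grid cells is skipped.

```python
import math


def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)


def dynamic_k_assign(gt_boxes, gt_classes, pred_boxes, pred_probs, topk=10, iou_weight=3.0):
    """Hypothetical SimOTA-style matcher: returns {prediction index: ground-truth index}."""
    candidates = {}  # prediction index -> list of (cost, gt index)
    for g, (gt_box, gt_cls) in enumerate(zip(gt_boxes, gt_classes)):
        ious = [box_iou(gt_box, p) for p in pred_boxes]
        # cost = -log(class prob) + iou_weight * -log(IoU); lower is better
        costs = [-math.log(probs[gt_cls] + 1e-9) - iou_weight * math.log(iou + 1e-9)
                 for probs, iou in zip(pred_probs, ious)]
        # dynamic k: integer part of the sum of the top-k IoUs, at least 1
        k = max(1, int(sum(sorted(ious, reverse=True)[:topk])))
        for j in sorted(range(len(pred_boxes)), key=lambda idx: costs[idx])[:k]:
            candidates.setdefault(j, []).append((costs[j], g))
    # a prediction claimed by several ground truths is kept only for the cheapest one
    return {j: min(claims)[1] for j, claims in candidates.items()}


gt_boxes = [(0, 0, 10, 10), (20, 20, 30, 30)]
gt_classes = [0, 1]
pred_boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (19, 19, 29, 29), (50, 50, 60, 60)]
pred_probs = [[0.9, 0.1], [0.8, 0.2], [0.2, 0.8], [0.5, 0.5]]
print(dynamic_k_assign(gt_boxes, gt_classes, pred_boxes, pred_probs))  # {0: 0, 2: 1}
```

In this toy example each ground-truth box's IoU sum stays below 2, so k is 1 and only the cheapest prediction becomes positive. When several predictions overlap a box well, k grows and several priors become positives for the same box.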
The list above is not exhaustive, and YoloV7 contains other improvements too. I have only listed the ones I find interesting and highly effective.
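The fusion mentioned in item 5 can be checked numerically. The pure-Python sketch below works on a single channel and assumes a RepVGG-style RepConv with three parallel branches (3x3 conv + BN, 1x1 conv + BN, and a BN-only identity branch). The helper names and all numbers are made up for illustration, and BN uses eps=1e-3 to match the Conv class in this post. Folding each BN into its conv and padding the 1x1 and identity branches to 3x3 kernels turns the three branches into one 3x3 convolution with a bias.

```python
import math


def conv2d(x, k, pad):
    """Single-channel 2D cross-correlation with stride 1 and zero padding."""
    h, w, ks = len(x), len(x[0]), len(k)
    xp = [[0.0] * (w + 2 * pad) for _ in range(h + 2 * pad)]
    for i in range(h):
        for j in range(w):
            xp[i + pad][j + pad] = x[i][j]
    out_h, out_w = h + 2 * pad - ks + 1, w + 2 * pad - ks + 1
    return [[sum(xp[i + a][j + b] * k[a][b] for a in range(ks) for b in range(ks))
             for j in range(out_w)] for i in range(out_h)]


def batch_norm(x, gamma, beta, mean, var, eps=1e-3):
    """Inference-time BatchNorm on a single-channel map."""
    scale = gamma / math.sqrt(var + eps)
    return [[(v - mean) * scale + beta for v in row] for row in x]


def fold_bn(kernel, gamma, beta, mean, var, eps=1e-3):
    """Fold BatchNorm into the preceding bias-free conv; returns (kernel, bias)."""
    scale = gamma / math.sqrt(var + eps)
    return [[v * scale for v in row] for row in kernel], beta - mean * scale


def add_maps(*maps):
    """Element-wise sum of equally sized 2D lists."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*maps)]


def pad_to_3x3(center):
    """Embed a 1x1 kernel value at the centre of a 3x3 kernel."""
    return [[0.0, 0.0, 0.0], [0.0, center, 0.0], [0.0, 0.0, 0.0]]


x = [[float((3 * i + j) % 5) for j in range(4)] for i in range(4)]
k3 = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1], [0.2, 0.1, -0.3]]
k1 = 0.7
bn3 = (1.2, 0.1, 0.5, 2.0)      # (gamma, beta, running mean, running var)
bn1 = (0.8, -0.2, 0.1, 0.5)
bn_id = (1.1, 0.05, 1.5, 1.8)

# training-time form: 3x3 conv+BN, 1x1 conv+BN and a BN-only identity branch, summed
y_train = add_maps(batch_norm(conv2d(x, k3, 1), *bn3),
                   batch_norm(conv2d(x, [[k1]], 0), *bn1),
                   batch_norm(x, *bn_id))

# deploy-time form: fold each BN, turn every branch into a 3x3 kernel, sum kernels and biases
w3, b3 = fold_bn(k3, *bn3)
w1, b1 = fold_bn(pad_to_3x3(k1), *bn1)
w_id, b_id = fold_bn(pad_to_3x3(1.0), *bn_id)
w_fused, b_fused = add_maps(w3, w1, w_id), b3 + b1 + b_id
y_fused = [[v + b_fused for v in row] for row in conv2d(x, w_fused, 1)]

print(max(abs(a - b) for ra, rb in zip(y_train, y_fused) for a, b in zip(ra, rb)))
```

The printed value is the largest difference between the two forms and should be at floating-point rounding level. In RepVGG the identity branch only exists when the input and output channel counts match and the stride is 1.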
YoloV7 implementation walkthrough
I. Overall structure
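Before studying YoloV7, it helps to have a rough idea of what it does, since that makes the network details easier to follow later. In the way it makes predictions, YoloV7 is not very different from earlier Yolo versions and is still divided into three parts:
Backbone, FPN and Yolo Head.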
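The Backbone is YoloV7's main feature-extraction network. The input image is first passed through the backbone, and the extracted features, called feature layers, form the feature set of the input image. From the backbone we take three feature layers to build the rest of the network; I call these the effective feature layers.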
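The FPN is YoloV7's enhanced feature-extraction network. The three effective feature layers obtained from the backbone are fused here, with the aim of combining feature information from different scales. In the FPN, the effective feature layers are used to extract further features. YoloV7 still uses the PANet structure: features are not only upsampled for fusion, but also downsampled again for a second round of fusion.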
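The Yolo Head is YoloV7's classifier and regressor. After the Backbone and FPN we have three enhanced effective feature layers. Each has a width, a height and a number of channels, so we can view a feature map as a collection of feature points. Each feature point has three prior boxes, and the channels at that point hold the features for those prior boxes. What the Yolo Head actually does is judge each feature point: whether an object corresponds to the prior boxes on that point. As in earlier Yolo versions, YoloV7 does not use a decoupled head: classification and regression are done together in a single 1x1 convolution.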
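So the whole YoloV7 network does three things: feature extraction, feature enhancement, and predicting the objects that correspond to the prior boxes.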
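The head bookkeeping described above can be sketched in pure Python. The sketch assumes 80 classes (COCO), three prior boxes per feature point, strides 8, 16 and 32 (which follow from the 80x80, 40x40 and 20x20 maps for a 640x640 input), and that the head's output channels are grouped prior box by prior box. The function names are my own.

```python
def head_channels(num_classes, anchors_per_point=3):
    """Output channels of the shared 1x1 head conv: per prior 4 box + 1 objectness + classes."""
    return anchors_per_point * (5 + num_classes)


def total_predictions(input_size=640, strides=(8, 16, 32), anchors_per_point=3):
    """Number of prior boxes over the three feature maps (80x80, 40x40, 20x20 at 640)."""
    return sum((input_size // s) ** 2 * anchors_per_point for s in strides)


def split_point(vector, num_classes, anchors_per_point=3):
    """Split one feature point's channel vector into per-prior (box, objectness, class scores)."""
    step = 5 + num_classes
    assert len(vector) == anchors_per_point * step
    parts = []
    for a in range(anchors_per_point):
        v = vector[a * step:(a + 1) * step]
        parts.append((v[:4], v[4], v[5:]))
    return parts


print(head_channels(80))        # 255
print(total_predictions(640))   # 25200
```

Because classification and regression share one 1x1 convolution, a single output vector per feature point carries the box offsets, objectness and class scores for all three prior boxes.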
II. Network structure
1. The Backbone
The backbone feature-extraction network used by YoloV7 has two important characteristics:
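1. It uses a multi-branch stacking module. The paper does not actually name this module, but after analyzing the source code I think this name fits well. In this post, the multi-branch stacking module is shown in the figure.
From the figure it should be clear why I call it a multi-branch stacking module: the input to the final stacking step consists of several branches. The leftmost branch is one Conv-BN-activation, the second from the left is also one Conv-BN-activation, the second from the right is three Conv-BN-activations, and the rightmost is five Conv-BN-activations.
After stacking, the four feature layers pass through one more Conv-BN-activation for feature integration.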
class Multi_Concat_Block(nn.Module):
    def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
        super(Multi_Concat_Block, self).__init__()
        c_ = int(c2 * e)

        self.ids = ids
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        # a chain of n 3x3 convs; every intermediate output is kept as a branch
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )
        # 1x1 conv that integrates the concatenated branches
        self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)

    def forward(self, x):
        x_1 = self.cv1(x)
        x_2 = self.cv2(x)

        x_all = [x_1, x_2]
        for i in range(len(self.cv3)):
            x_2 = self.cv3[i](x_2)
            x_all.append(x_2)

        # concatenate the branches selected by ids along the channel axis
        out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
        return out
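The role of `ids` is easy to miss in the code above. The pure-Python helper below (a hypothetical name, not part of the repository) lists which entries of `x_all` are concatenated and how many channels result. With the backbone setting `ids=[-1, -3, -5, -6]` and n=4 it picks exactly the four branches from the figure: one, one, three and five convolutions deep. The channel formula used for `cv4`, `c_ * 2 + c2 * (len(ids) - 2)`, is only valid when `ids` includes both the cv1 and cv2 outputs, which holds for every setting used in this post.

```python
def concat_branches(c_, c2, n, ids):
    """Which x_all entries Multi_Concat_Block concatenates, and the total channel count."""
    # x_all = [cv1(x), cv2(x), cv3[0](...), ..., cv3[n-1](...)]
    names = ["cv1 (1 conv deep)", "cv2 (1 conv deep)"]
    names += [f"cv3[{i}] ({i + 2} convs deep)" for i in range(n)]
    widths = [c_, c_] + [c2] * n
    picked = [names[i] for i in ids]
    channels = sum(widths[i] for i in ids)
    return picked, channels


# backbone setting for phi='l' in dark2: c_ = c2 = 64, n = 4, ids = [-1, -3, -5, -6]
picked, channels = concat_branches(64, 64, 4, [-1, -3, -5, -6])
print(picked)
print(channels)                  # 256
# the formula used for cv4's input channels gives the same number
print(64 * 2 + 64 * (4 - 2))     # 256
```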
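All this stacking effectively amounts to denser residual-style (skip) connections. Residual networks are easy to optimize and can gain accuracy from considerably increased depth. Their skip connections alleviate the vanishing-gradient problem that comes with making deep neural networks deeper.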
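2. It uses a novel transition module, Transition_Block, for downsampling. In convolutional neural networks, the usual downsampling transition is either a 3x3 convolution with stride 2 or a max pooling with stride 2. YoloV7 combines the two: the transition module has two branches, as shown in the figure. The left branch is a stride-2 max pooling followed by a 1x1 convolution; the right branch is a 1x1 convolution followed by a 3x3 convolution with stride 2. The outputs of the two branches are concatenated.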
class MP(nn.Module):
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)

class Transition_Block(nn.Module):
    def __init__(self, c1, c2):
        super(Transition_Block, self).__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c1, c2, 1, 1)
        self.cv3 = Conv(c2, c2, 3, 2)

        self.mp  = MP()

    def forward(self, x):
        # left branch: 2x2 max pooling with stride 2, then a 1x1 conv
        x_1 = self.mp(x)
        x_1 = self.cv1(x_1)

        # right branch: a 1x1 conv, then a 3x3 conv with stride 2
        x_2 = self.cv2(x)
        x_2 = self.cv3(x_2)

        # concatenate along the channel axis: 2 * c2 output channels
        return torch.cat([x_2, x_1], 1)
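Both branches must produce feature maps of the same size, otherwise `torch.cat` fails. The pure-Python sketch below (helper names are my own) applies the usual output-size formula to each branch, assuming a square input and the `autopad` padding of this post (p=1 for the 3x3 convolution). For an even side the branches agree; for an odd side max pooling rounds down while the padded 3x3 convolution rounds up.

```python
def out_size(size, k, s, p):
    """Output side length of a conv or pooling layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1


def transition_shapes(h, c2):
    """Side lengths of both Transition_Block branches and the concatenated channel count."""
    left = out_size(h, k=2, s=2, p=0)    # MaxPool2d(2, 2); the following 1x1 conv keeps the size
    right = out_size(h, k=3, s=2, p=1)   # 1x1 conv keeps the size; 3x3 stride-2 conv with autopad p=1
    return left, right, 2 * c2           # torch.cat of two c2-channel maps


print(transition_shapes(160, 128))   # (80, 80, 256), e.g. dark3 for phi='l'
print(transition_shapes(81, 128))    # (40, 41, 256): branch sizes differ, so torch.cat would fail
```

Keeping the input side a multiple of 32 avoids this: after the two stride-2 convolutions in the stem and dark2, every Transition_Block then receives an even size.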
The complete backbone code is:
import torch
import torch.nn as nn


def autopad(k, p=None):
    # "same" padding: pad by k // 2 when no padding is given
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p

class SiLU(nn.Module):
    # SiLU activation: x * sigmoid(x)
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

class Conv(nn.Module):
    # Conv2d -> BatchNorm2d -> activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn   = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        self.act  = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        # used after the BatchNorm has been folded into self.conv
        return self.act(self.conv(x))
class Multi_Concat_Block(nn.Module):
    def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
        super(Multi_Concat_Block, self).__init__()
        c_ = int(c2 * e)

        self.ids = ids
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )
        self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)

    def forward(self, x):
        x_1 = self.cv1(x)
        x_2 = self.cv2(x)

        x_all = [x_1, x_2]
        for i in range(len(self.cv3)):
            x_2 = self.cv3[i](x_2)
            x_all.append(x_2)

        out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
        return out
class MP(nn.Module):
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)

class Transition_Block(nn.Module):
    def __init__(self, c1, c2):
        super(Transition_Block, self).__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c1, c2, 1, 1)
        self.cv3 = Conv(c2, c2, 3, 2)

        self.mp  = MP()

    def forward(self, x):
        x_1 = self.mp(x)
        x_1 = self.cv1(x_1)

        x_2 = self.cv2(x)
        x_2 = self.cv3(x_2)

        return torch.cat([x_2, x_1], 1)
class Backbone(nn.Module):
    def __init__(self, transition_channels, block_channels, n, phi, pretrained=False):
        super().__init__()
        #-----------------------------------------------#
        #   The input image is 640, 640, 3
        #-----------------------------------------------#
        ids = {
            'l' : [-1, -3, -5, -6],
            'x' : [-1, -3, -5, -7, -8],
        }[phi]
        # 640, 640, 3 => 320, 320, transition_channels * 2
        self.stem = nn.Sequential(
            Conv(3, transition_channels, 3, 1),
            Conv(transition_channels, transition_channels * 2, 3, 2),
            Conv(transition_channels * 2, transition_channels * 2, 3, 1),
        )
        # 320, 320 => 160, 160, transition_channels * 8
        self.dark2 = nn.Sequential(
            Conv(transition_channels * 2, transition_channels * 4, 3, 2),
            Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids),
        )
        # 160, 160 => 80, 80, transition_channels * 16
        self.dark3 = nn.Sequential(
            Transition_Block(transition_channels * 8, transition_channels * 4),
            Multi_Concat_Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids),
        )
        # 80, 80 => 40, 40, transition_channels * 32
        self.dark4 = nn.Sequential(
            Transition_Block(transition_channels * 16, transition_channels * 8),
            Multi_Concat_Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids),
        )
        # 40, 40 => 20, 20, transition_channels * 32
        self.dark5 = nn.Sequential(
            Transition_Block(transition_channels * 32, transition_channels * 16),
            Multi_Concat_Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids),
        )

        if pretrained:
            url = {
                "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth',
                "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth',
            }[phi]
            checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data")
            self.load_state_dict(checkpoint, strict=False)
            print("Load weights from " + url.split('/')[-1])

    def forward(self, x):
        x = self.stem(x)
        x = self.dark2(x)
        #-----------------------------------------------#
        #   dark3 outputs 80, 80, 512 (phi='l'): an effective feature layer
        #-----------------------------------------------#
        x = self.dark3(x)
        feat1 = x
        #-----------------------------------------------#
        #   dark4 outputs 40, 40, 1024 (phi='l'): an effective feature layer
        #-----------------------------------------------#
        x = self.dark4(x)
        feat2 = x
        #-----------------------------------------------#
        #   dark5 outputs 20, 20, 1024 (phi='l'): an effective feature layer
        #-----------------------------------------------#
        x = self.dark5(x)
        feat3 = x
        return feat1, feat2, feat3
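As a cross-check of the shape comments in `forward`, the pure-Python sketch below (hypothetical helper names) traces the spatial size and channel count stage by stage for a 640x640 input. It only uses the stride and output-channel settings visible in the Backbone code. `transition_channels` is 32 for phi='l' and 40 for phi='x', as set in YoloBody. `block_channels` only affects the widths inside Multi_Concat_Block, so it does not appear.

```python
def out_size(size, k, s, p):
    """floor((size + 2p - k) / s) + 1, the usual conv output-size formula."""
    return (size + 2 * p - k) // s + 1


def backbone_shapes(input_size=640, phi="l"):
    """(H, W, C) of feat1, feat2, feat3 following the Backbone definition."""
    tc = {"l": 32, "x": 40}[phi]              # transition_channels as set in YoloBody
    h = out_size(input_size, 3, 2, 1)          # stem: Conv(..., 3, 2) halves the size
    h = out_size(h, 3, 2, 1)                   # dark2: Conv(..., 3, 2) halves it again
    shapes = []
    # dark3, dark4, dark5: Transition_Block halves the size, then
    # Multi_Concat_Block outputs tc*16, tc*32 and tc*32 channels respectively
    for out_c in (tc * 16, tc * 32, tc * 32):
        h = out_size(h, 2, 2, 0)               # max-pool branch; the conv branch matches for even h
        shapes.append((h, h, out_c))
    return shapes


print(backbone_shapes(640, "l"))   # [(80, 80, 512), (40, 40, 1024), (20, 20, 1024)]
print(backbone_shapes(640, "x"))   # [(80, 80, 640), (40, 40, 1280), (20, 20, 1280)]
```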
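2. Building the FPN feature pyramid for enhanced feature extraction
In the feature-utilization part, YoloV7 extracts multiple feature layers for object detection, three in total.
The three feature layers sit at different depths of the backbone: the middle, lower-middle and bottom layers. For a (640,640,3) input, their shapes are feat1=(80,80,512), feat2=(40,40,1024) and feat3=(20,20,1024).
With these three effective feature layers we build the FPN as follows (in this post, the SPPCSPC structure is counted as part of the FPN):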
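- The feat3=(20,20,1024) layer first goes through SPPCSPC for feature extraction. This structure enlarges YoloV7's receptive field, and the result is P5=(20,20,512).
- P5 passes through a 1x1 convolution to adjust its channels and is upsampled (nearest-neighbour nn.Upsample in the code). It is then concatenated with feat2=(40,40,1024) after feat2 has passed through one convolution, and a Multi_Concat_Block extracts features to give P4, with shape (40,40,256).
- P4 passes through a 1x1 convolution to adjust its channels and is upsampled. It is then concatenated with feat1=(80,80,512) after feat1 has passed through one convolution, and a Multi_Concat_Block extracts features to give P3_out, with shape (80,80,128).
- P3_out=(80,80,128) is downsampled by a Transition_Block, concatenated with P4, and passed through a Multi_Concat_Block to give P4_out, with shape (40,40,256).
- P4_out=(40,40,256) is downsampled by a Transition_Block, concatenated with P5, and passed through a Multi_Concat_Block to give P5_out, with shape (20,20,512).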
The feature pyramid fuses feature layers of different shapes, which helps extract better features.
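The steps above can be checked with a small pure-Python channel trace. The sketch uses the channel settings from the YoloBody code in this post (for phi='l', transition_channels is 32) and the fact that a Transition_Block(c1, c2) outputs 2 * c2 channels. The function name is my own. The concatenated widths it computes should equal the first (c1) argument of each Multi_Concat_Block in YoloBody.

```python
def fpn_shapes(input_size=640, phi="l"):
    """Trace (H, W, C) of P5, P4, P3_out, P4_out, P5_out using the YoloBody channel settings."""
    tc = {"l": 32, "x": 40}[phi]                  # transition_channels
    s3, s4, s5 = input_size // 8, input_size // 16, input_size // 32
    p5 = tc * 16                                  # SPPCSPC: tc*32 -> tc*16
    cat_up1 = tc * 8 + tc * 8                     # upsampled conv_for_P5(P5) ++ conv_for_feat2(feat2)
    p4 = tc * 8                                   # conv3_for_upsample1 output
    cat_up2 = tc * 4 + tc * 4                     # upsampled conv_for_P4(P4) ++ conv_for_feat1(feat1)
    p3_out = tc * 4                               # conv3_for_upsample2 output
    cat_down1 = 2 * (tc * 4) + p4                 # down_sample1(P3_out) ++ P4
    p4_out = tc * 8                               # conv3_for_downsample1 output
    cat_down2 = 2 * (tc * 8) + p5                 # down_sample2(P4_out) ++ P5
    p5_out = tc * 16                              # conv3_for_downsample2 output
    shapes = {"P5": (s5, s5, p5), "P4": (s4, s4, p4), "P3_out": (s3, s3, p3_out),
              "P4_out": (s4, s4, p4_out), "P5_out": (s5, s5, p5_out)}
    concat_inputs = {"conv3_for_upsample1": cat_up1, "conv3_for_upsample2": cat_up2,
                     "conv3_for_downsample1": cat_down1, "conv3_for_downsample2": cat_down2}
    return shapes, concat_inputs


shapes, concat_inputs = fpn_shapes(640, "l")
print(shapes)
print(concat_inputs)   # expected to match tc*16, tc*8, tc*16, tc*32 in YoloBody
```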
#---------------------------------------------------#
#   yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi, pretrained=False):
        super(YoloBody, self).__init__()
        #-----------------------------------------------#
        #   Parameters for the different YoloV7 versions
        #-----------------------------------------------#
        transition_channels = {'l' : 32, 'x' : 40}[phi]
        block_channels      = 32
        panet_channels      = {'l' : 32, 'x' : 64}[phi]
        e       = {'l' : 2, 'x' : 1}[phi]
        n       = {'l' : 4, 'x' : 6}[phi]
        ids     = {'l' : [-1, -2, -3, -4, -5, -6], 'x' : [-1, -3, -5, -7, -8]}[phi]
        conv    = {'l' : RepConv, 'x' : Conv}[phi]
        #-----------------------------------------------#
        #   The input image is 640, 640, 3
        #-----------------------------------------------#

        #---------------------------------------------------#
        #   Build the backbone.
        #   It returns three effective feature layers whose shapes are (phi='l'):
        #   80, 80, 512
        #   40, 40, 1024
        #   20, 20, 1024
        #---------------------------------------------------#
        self.backbone   = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained)

        self.upsample   = nn.Upsample(scale_factor=2, mode="nearest")

        self.sppcspc                = SPPCSPC(transition_channels * 32, transition_channels * 16)
        self.conv_for_P5            = Conv(transition_channels * 16, transition_channels * 8)
        self.conv_for_feat2         = Conv(transition_channels * 32, transition_channels * 8)
        self.conv3_for_upsample1    = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        self.conv_for_P4            = Conv(transition_channels * 8, transition_channels * 4)
        self.conv_for_feat1         = Conv(transition_channels * 16, transition_channels * 4)
        self.conv3_for_upsample2    = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)

        self.down_sample1           = Transition_Block(transition_channels * 4, transition_channels * 4)
        self.conv3_for_downsample1  = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        self.down_sample2           = Transition_Block(transition_channels * 8, transition_channels * 8)
        self.conv3_for_downsample2  = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)

        self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1)
        self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1)
        self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1)

        self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
        self
class="token punctuation">.</span>yolo_head_P4 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span>transition_channels <span class="token operator">*</span> <span class="token number">16</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>anchors_mask<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token punctuation">(</span><span class="token number">5</span> <span class="token operator">+</span> num_classes<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>yolo_head_P5 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span>transition_channels <span class="token operator">*</span> <span class="token number">32</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>anchors_mask<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token punctuation">(</span><span class="token number">5</span> <span class="token operator">+</span> num_classes<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">fuse</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token 
keyword">print</span><span class="token punctuation">(</span><span class="token string">'Fusing layers... '</span><span class="token punctuation">)</span> <span class="token keyword">for</span> m <span class="token keyword">in</span> self<span class="token punctuation">.</span>modules<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">if</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>m<span class="token punctuation">,</span> RepConv<span class="token punctuation">)</span><span class="token punctuation">:</span> m<span class="token punctuation">.</span>fuse_repvgg_block<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">elif</span> <span class="token builtin">type</span><span class="token punctuation">(</span>m<span class="token punctuation">)</span> <span class="token keyword">is</span> Conv <span class="token keyword">and</span> <span class="token builtin">hasattr</span><span class="token punctuation">(</span>m<span class="token punctuation">,</span> <span class="token string">'bn'</span><span class="token punctuation">)</span><span class="token punctuation">:</span> m<span class="token punctuation">.</span>conv <span class="token operator">=</span> fuse_conv_and_bn<span class="token punctuation">(</span>m<span class="token punctuation">.</span>conv<span class="token punctuation">,</span> m<span class="token punctuation">.</span>bn<span class="token punctuation">)</span> <span class="token builtin">delattr</span><span class="token punctuation">(</span>m<span class="token punctuation">,</span> <span class="token string">'bn'</span><span class="token punctuation">)</span> m<span class="token punctuation">.</span>forward <span class="token operator">=</span> m<span class="token punctuation">.</span>fuseforward <span class="token keyword">return</span> self <span 
class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># backbone</span> feat1<span class="token punctuation">,</span> feat2<span class="token punctuation">,</span> feat3 <span class="token operator">=</span> self<span class="token punctuation">.</span>backbone<span class="token punctuation">.</span>forward<span class="token punctuation">(</span>x<span class="token punctuation">)</span> P5 <span class="token operator">=</span> self<span class="token punctuation">.</span>sppcspc<span class="token punctuation">(</span>feat3<span class="token punctuation">)</span> P5_conv <span class="token operator">=</span> self<span class="token punctuation">.</span>conv_for_P5<span class="token punctuation">(</span>P5<span class="token punctuation">)</span> P5_upsample <span class="token operator">=</span> self<span class="token punctuation">.</span>upsample<span class="token punctuation">(</span>P5_conv<span class="token punctuation">)</span> P4 <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>conv_for_feat2<span class="token punctuation">(</span>feat2<span class="token punctuation">)</span><span class="token punctuation">,</span> P5_upsample<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> P4 <span class="token operator">=</span> self<span class="token punctuation">.</span>conv3_for_upsample1<span class="token punctuation">(</span>P4<span class="token punctuation">)</span> P4_conv <span class="token operator">=</span> self<span class="token punctuation">.</span>conv_for_P4<span class="token 
punctuation">(</span>P4<span class="token punctuation">)</span> P4_upsample <span class="token operator">=</span> self<span class="token punctuation">.</span>upsample<span class="token punctuation">(</span>P4_conv<span class="token punctuation">)</span> P3 <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>conv_for_feat1<span class="token punctuation">(</span>feat1<span class="token punctuation">)</span><span class="token punctuation">,</span> P4_upsample<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> P3 <span class="token operator">=</span> self<span class="token punctuation">.</span>conv3_for_upsample2<span class="token punctuation">(</span>P3<span class="token punctuation">)</span> P3_downsample <span class="token operator">=</span> self<span class="token punctuation">.</span>down_sample1<span class="token punctuation">(</span>P3<span class="token punctuation">)</span> P4 <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>P3_downsample<span class="token punctuation">,</span> P4<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> P4 <span class="token operator">=</span> self<span class="token punctuation">.</span>conv3_for_downsample1<span class="token punctuation">(</span>P4<span class="token punctuation">)</span> P4_downsample <span class="token operator">=</span> self<span class="token punctuation">.</span>down_sample2<span class="token punctuation">(</span>P4<span class="token punctuation">)</span> P5 <span class="token operator">=</span> torch<span class="token 
punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>P4_downsample<span class="token punctuation">,</span> P5<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> P5 <span class="token operator">=</span> self<span class="token punctuation">.</span>conv3_for_downsample2<span class="token punctuation">(</span>P5<span class="token punctuation">)</span> P3 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_1<span class="token punctuation">(</span>P3<span class="token punctuation">)</span> P4 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_2<span class="token punctuation">(</span>P4<span class="token punctuation">)</span> P5 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_3<span class="token punctuation">(</span>P5<span class="token punctuation">)</span> <span class="token comment">#---------------------------------------------------#</span> <span class="token comment"># 第三个特征层</span> <span class="token comment"># y3=(batch_size, 75, 80, 80)</span> <span class="token comment">#---------------------------------------------------#</span> out2 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P3<span class="token punctuation">(</span>P3<span class="token punctuation">)</span> <span class="token comment">#---------------------------------------------------#</span> <span class="token comment"># 第二个特征层</span> <span class="token comment"># y2=(batch_size, 75, 40, 40)</span> <span class="token comment">#---------------------------------------------------#</span> out1 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P4<span class="token punctuation">(</span>P4<span class="token punctuation">)</span> <span class="token 
comment">#---------------------------------------------------#</span> <span class="token comment"># 第一个特征层</span> <span class="token comment"># y1=(batch_size, 75, 20, 20)</span> <span class="token comment">#---------------------------------------------------#</span> out0 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P5<span class="token punctuation">(</span>P5<span class="token punctuation">)</span> <span class="token keyword">return</span> <span class="token punctuation">[</span>out0<span class="token punctuation">,</span> out1<span class="token punctuation">,</span> out2<span class="token punctuation">]</span>
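As a quick check of the channel arithmetic in the bottom-up path, the plain-Python sketch below recomputes the concatenation widths. It assumes that `Transition_Block(c1, c2)` concatenates a max-pooling branch and a stride-2 convolution branch with `c2` channels each (the parallel downsampling described for the backbone), that the third argument `c3` of `Multi_Concat_Block` is its output width, and that the top-down P4 and the SPPCSPC output have `transition_channels * 8` and `transition_channels * 16` channels, since those definitions are not shown in this excerpt. `neck_channels` is a hypothetical helper, and `tc = 32` is chosen only because it reproduces the 128/256/512 widths quoted in the next section.

```python
def neck_channels(tc):
    """Channel counts in the bottom-up (PAN) half of the neck for transition_channels = tc."""
    def transition_out(c2):
        # Assumption: Transition_Block(c1, c2) concatenates a max-pool branch and a
        # stride-2 conv branch with c2 channels each, so it outputs 2 * c2 channels.
        return 2 * c2

    p3 = tc * 4       # conv3_for_upsample2 output (its c3 argument)
    p4_top = tc * 8   # top-down P4 width (assumed; definition not shown in this excerpt)
    p5_spp = tc * 16  # SPPCSPC output width (assumed; definition not shown in this excerpt)

    cat1 = transition_out(tc * 4) + p4_top   # torch.cat([P3_downsample, P4], 1)
    p4 = tc * 8                              # conv3_for_downsample1 output
    cat2 = transition_out(tc * 8) + p5_spp   # torch.cat([P4_downsample, P5], 1)
    p5 = tc * 16                             # conv3_for_downsample2 output

    heads_in = (2 * p3, 2 * p4, 2 * p5)      # rep_conv_1/2/3 double the channels
    return {"cat1": cat1, "cat2": cat2, "fpn_out": (p3, p4, p5), "heads_in": heads_in}


tc = 32  # hypothetical value; gives the 128/256/512 FPN widths quoted in the text
ch = neck_channels(tc)
assert ch["cat1"] == tc * 16   # first argument of conv3_for_downsample1
assert ch["cat2"] == tc * 32   # first argument of conv3_for_downsample2
assert ch["heads_in"] == (tc * 8, tc * 16, tc * 32)  # in_channels of yolo_head_P3/P4/P5
print(ch)  # {'cat1': 512, 'cat2': 1024, 'fpn_out': (128, 256, 512), 'heads_in': (256, 512, 1024)}
```

The asserts confirm that the concatenated widths match the first arguments of `conv3_for_downsample1` and `conv3_for_downsample2`, and that the RepConv outputs match the input channels of the three heads.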
3. Using the Yolo Head to obtain predictions
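The FPN feature pyramid gives us three enhanced feature layers. For a 640x640 input image their shapes are (20,20,512), (40,40,256) and (80,80,128). These three feature layers are then passed to the Yolo Head to obtain the prediction results.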
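Unlike earlier Yolo versions, YoloV7 places a RepConv block in front of the Yolo Head. RepConv borrows the idea of RepVGG: during training, a specially designed residual structure helps optimization. It consists of a 3x3 conv + BN branch, a 1x1 conv + BN branch and, when the input and output channel counts are equal and the stride is 1, an identity branch with BN only; the three outputs are summed before the activation. At inference time this multi-branch structure can be converted into a single equivalent 3x3 convolution. The network becomes simpler and faster, and prediction accuracy does not drop, because the conversion is mathematically exact (up to floating-point rounding).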
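To see why the three branches can be collapsed into one convolution, here is a small numerical check written with NumPy instead of PyTorch so that it is self-contained. It mirrors what `_fuse_bn_tensor`, `_pad_1x1_to_3x3_tensor` and the identity-kernel construction in the `RepConv` code below do, for the simple case `groups=1`, stride 1 and `c1 == c2`: each BatchNorm (in eval mode) is folded into its convolution, the 1x1 kernel is zero-padded to 3x3, the identity branch becomes a 3x3 kernel with a 1 at the centre, and kernels and biases are summed. The helper names (`conv2d`, `batchnorm`, `fold_bn`) are hypothetical, and the activation is left out because it is applied after the sum in both forms.

```python
import numpy as np


def conv2d(x, k, b=None, pad=0):
    """Naive stride-1, groups=1 convolution (cross-correlation), NCHW input, OIHW kernel."""
    n, _, h, w = x.shape
    c_out, _, kh, kw = k.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))  # zero padding
    h_out, w_out = h + 2 * pad - kh + 1, w + 2 * pad - kw + 1
    out = np.zeros((n, c_out, h_out, w_out))
    for i in range(kh):
        for j in range(kw):
            patch = xp[:, :, i:i + h_out, j:j + w_out]
            out += np.einsum("nchw,oc->nohw", patch, k[:, :, i, j])
    if b is not None:
        out += b.reshape(1, -1, 1, 1)
    return out


def batchnorm(x, gamma, beta, mean, var, eps=1e-3):
    """BatchNorm2d in eval mode: normalise with the running statistics."""
    def col(v):
        return v.reshape(1, -1, 1, 1)
    return col(gamma) * (x - col(mean)) / np.sqrt(col(var) + eps) + col(beta)


def fold_bn(kernel, gamma, beta, mean, var, eps=1e-3):
    """Fold an eval-mode BN into the preceding bias-free conv (same formula as _fuse_bn_tensor)."""
    std = np.sqrt(var + eps)
    return kernel * (gamma / std).reshape(-1, 1, 1, 1), beta - mean * gamma / std


rng = np.random.default_rng(0)
c = 4                                   # c1 == c2, so the identity branch exists
x = rng.normal(size=(2, c, 6, 6))
k3 = rng.normal(size=(c, c, 3, 3))      # rbr_dense kernel
k1 = rng.normal(size=(c, c, 1, 1))      # rbr_1x1 kernel
bn3, bn1, bn_id = [
    (rng.uniform(0.5, 1.5, c), rng.normal(size=c), rng.normal(size=c), rng.uniform(0.5, 1.5, c))
    for _ in range(3)
]                                       # (gamma, beta, running_mean, running_var) per branch

# Training-time structure: three branches summed (activation omitted)
y_multi = (batchnorm(conv2d(x, k3, pad=1), *bn3)
           + batchnorm(conv2d(x, k1, pad=0), *bn1)
           + batchnorm(x, *bn_id))

# Deploy-time structure: one 3x3 conv with the summed kernel and bias
w3, b3 = fold_bn(k3, *bn3)
w1, b1 = fold_bn(k1, *bn1)
w1 = np.pad(w1, ((0, 0), (0, 0), (1, 1), (1, 1)))   # 1x1 -> 3x3, like _pad_1x1_to_3x3_tensor
id_kernel = np.zeros((c, c, 3, 3))
id_kernel[np.arange(c), np.arange(c), 1, 1] = 1.0   # identity written as a 3x3 conv
w_id, b_id = fold_bn(id_kernel, *bn_id)
y_fused = conv2d(x, w3 + w1 + w_id, b3 + b1 + b_id, pad=1)

print(np.allclose(y_multi, y_fused))  # True
```

Padding the 1x1 kernel with zeros places its weight at the centre tap, so with padding 1 it sees exactly the same pixel as the unpadded 1x1 conv; that is what makes all three branches addable as a single 3x3 kernel.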
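For each feature layer, a 1x1 convolution then adjusts the number of channels. The final channel count depends on the number of classes to distinguish: in YoloV7 every feature point on every feature layer has 3 prior boxes, so each head outputs 3 × (4 + 1 + num_classes) channels.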
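If the VOC dataset is used, there are 20 classes, so the last dimension is 75 = 3x25 and the three feature layers have shapes (20,20,75), (40,40,75) and (80,80,75) in (height, width, channels) order; the corresponding PyTorch tensors are (batch_size, 75, 20, 20), (batch_size, 75, 40, 40) and (batch_size, 75, 80, 80).
The 75 splits into 3 groups of 25, one group of 25 parameters per prior box, and each 25 splits into 4+1+20:
The first 4 parameters are the regression parameters of the prior box; adjusting the prior box with them gives the predicted box;
The 5th parameter indicates whether the prior box contains an object;
The last 20 parameters give the class of the object in the prior box.
If the COCO dataset is used, there are 80 classes, so the last dimension is 255 = 3x85 and the three feature layers have shapes (20,20,255), (40,40,255) and (80,80,255).
The 255 splits into 3 groups of 85, one per prior box, and each 85 splits into 4+1+80:
The first 4 parameters are the regression parameters of the prior box; adjusting the prior box with them gives the predicted box;
The 5th parameter indicates whether the prior box contains an object;
The last 80 parameters give the class of the object in the prior box.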
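The shapes above and the 3 × (4+1+K) layout can be sketched as follows. This NumPy sketch assumes a square input, strides of 32, 16 and 8, and channels grouped anchor by anchor (3 consecutive blocks of 5+K values), which is the layout implied by the 75 = 3x25 split. The function names are hypothetical, and the values are the raw head outputs, before any sigmoid or box decoding.

```python
import numpy as np


def head_shapes(input_size, num_classes, num_anchors=3, strides=(32, 16, 8)):
    """(H, W, C) of the three Yolo Head outputs, coarsest level first."""
    channels = num_anchors * (4 + 1 + num_classes)  # box (4) + objectness (1) + classes
    return [(input_size // s, input_size // s, channels) for s in strides]


def split_head_output(out, num_anchors=3):
    """Split a raw (B, A*(5+K), H, W) head tensor into box, objectness and class parts."""
    b, ch, h, w = out.shape
    attrs = ch // num_anchors                       # 5 + K values per prior box
    # (B, A, 5+K, H, W) -> (B, A, H, W, 5+K): one vector per prior box per feature point
    p = out.reshape(b, num_anchors, attrs, h, w).transpose(0, 1, 3, 4, 2)
    return p[..., :4], p[..., 4], p[..., 5:]


print(head_shapes(640, 20))   # VOC:  [(20, 20, 75), (40, 40, 75), (80, 80, 75)]
print(head_shapes(640, 80))   # COCO: [(20, 20, 255), (40, 40, 255), (80, 80, 255)]

out0 = np.random.default_rng(0).normal(size=(1, 75, 20, 20))  # e.g. yolo_head_P5 on VOC
box, obj, cls = split_head_output(out0)
print(box.shape, obj.shape, cls.shape)  # (1, 3, 20, 20, 4) (1, 3, 20, 20) (1, 3, 20, 20, 20)
```

Moving the 5+K axis to the end gives one parameter vector per prior box per feature point, which makes the 4 + 1 + K slicing a simple indexing operation.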
The implementation is as follows:
import numpy as np
import torch
import torch.nn as nn
from nets.backbone import Backbone, Multi_Concat_Block, Conv, SiLU, Transition_Block, autopad
class RepConv(nn.Module):
# Represented convolution
# https://arxiv.org/abs/2101.03697
def init(self, c1, c2, k=3, s=1, p=None, g=1, act=SiLU(), deploy=False):
super(RepConv, self).init()
self.deploy = deploy
self.groups = g
self.in_channels = c1
self.out_channels = c2
<span class="token keyword">assert</span> k <span class="token operator">==</span> <span class="token number">3</span>
<span class="token keyword">assert</span> autopad<span class="token punctuation">(</span>k<span class="token punctuation">,</span> p<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token number">1</span>
padding_11 <span class="token operator">=</span> autopad<span class="token punctuation">(</span>k<span class="token punctuation">,</span> p<span class="token punctuation">)</span> <span class="token operator">-</span> k <span class="token operator">//</span> <span class="token number">2</span>
self<span class="token punctuation">.</span>act <span class="token operator">=</span> nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.1</span><span class="token punctuation">,</span> inplace<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span> <span class="token keyword">if</span> act <span class="token keyword">is</span> <span class="token boolean">True</span> <span class="token keyword">else</span> <span class="token punctuation">(</span>act <span class="token keyword">if</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>act<span class="token punctuation">,</span> nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span> <span class="token keyword">else</span> nn<span class="token punctuation">.</span>Identity<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> deploy<span class="token punctuation">:</span>
self<span class="token punctuation">.</span>rbr_reparam <span class="token operator">=</span> nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span>c1<span class="token punctuation">,</span> c2<span class="token punctuation">,</span> k<span class="token punctuation">,</span> s<span class="token punctuation">,</span> autopad<span class="token punctuation">(</span>k<span class="token punctuation">,</span> p<span class="token punctuation">)</span><span class="token punctuation">,</span> groups<span class="token operator">=</span>g<span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>rbr_identity <span class="token operator">=</span> <span class="token punctuation">(</span>nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span>num_features<span class="token operator">=</span>c1<span class="token punctuation">,</span> eps<span class="token operator">=</span><span class="token number">0.001</span><span class="token punctuation">,</span> momentum<span class="token operator">=</span><span class="token number">0.03</span><span class="token punctuation">)</span> <span class="token keyword">if</span> c2 <span class="token operator">==</span> c1 <span class="token keyword">and</span> s <span class="token operator">==</span> <span class="token number">1</span> <span class="token keyword">else</span> <span class="token boolean">None</span><span class="token punctuation">)</span>
self<span class="token punctuation">.</span>rbr_dense <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span>c1<span class="token punctuation">,</span> c2<span class="token punctuation">,</span> k<span class="token punctuation">,</span> s<span class="token punctuation">,</span> autopad<span class="token punctuation">(</span>k<span class="token punctuation">,</span> p<span class="token punctuation">)</span><span class="token punctuation">,</span> groups<span class="token operator">=</span>g<span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span>num_features<span class="token operator">=</span>c2<span class="token punctuation">,</span> eps<span class="token operator">=</span><span class="token number">0.001</span><span class="token punctuation">,</span> momentum<span class="token operator">=</span><span class="token number">0.03</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
<span class="token punctuation">)</span>
self<span class="token punctuation">.</span>rbr_1x1 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span> c1<span class="token punctuation">,</span> c2<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> s<span class="token punctuation">,</span> padding_11<span class="token punctuation">,</span> groups<span class="token operator">=</span>g<span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span>num_features<span class="token operator">=</span>c2<span class="token punctuation">,</span> eps<span class="token operator">=</span><span class="token number">0.001</span><span class="token punctuation">,</span> momentum<span class="token operator">=</span><span class="token number">0.03</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> inputs<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> <span class="token builtin">hasattr</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> <span class="token string">"rbr_reparam"</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> self<span class="token punctuation">.</span>act<span class="token punctuation">(</span>self<span class="token punctuation">.</span>rbr_reparam<span class="token punctuation">(</span>inputs<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> self<span class="token punctuation">.</span>rbr_identity <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
id_out <span class="token operator">=</span> <span class="token number">0</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
id_out <span class="token operator">=</span> self<span class="token punctuation">.</span>rbr_identity<span class="token punctuation">(</span>inputs<span class="token punctuation">)</span>
<span class="token keyword">return</span> self<span class="token punctuation">.</span>act<span class="token punctuation">(</span>self<span class="token punctuation">.</span>rbr_dense<span class="token punctuation">(</span>inputs<span class="token punctuation">)</span> <span class="token operator">+</span> self<span class="token punctuation">.</span>rbr_1x1<span class="token punctuation">(</span>inputs<span class="token punctuation">)</span> <span class="token operator">+</span> id_out<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">get_equivalent_kernel_bias</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
kernel3x3<span class="token punctuation">,</span> bias3x3 <span class="token operator">=</span> self<span class="token punctuation">.</span>_fuse_bn_tensor<span class="token punctuation">(</span>self<span class="token punctuation">.</span>rbr_dense<span class="token punctuation">)</span>
kernel1x1<span class="token punctuation">,</span> bias1x1 <span class="token operator">=</span> self<span class="token punctuation">.</span>_fuse_bn_tensor<span class="token punctuation">(</span>self<span class="token punctuation">.</span>rbr_1x1<span class="token punctuation">)</span>
kernelid<span class="token punctuation">,</span> biasid <span class="token operator">=</span> self<span class="token punctuation">.</span>_fuse_bn_tensor<span class="token punctuation">(</span>self<span class="token punctuation">.</span>rbr_identity<span class="token punctuation">)</span>
<span class="token keyword">return</span> <span class="token punctuation">(</span>
kernel3x3 <span class="token operator">+</span> self<span class="token punctuation">.</span>_pad_1x1_to_3x3_tensor<span class="token punctuation">(</span>kernel1x1<span class="token punctuation">)</span> <span class="token operator">+</span> kernelid<span class="token punctuation">,</span>
bias3x3 <span class="token operator">+</span> bias1x1 <span class="token operator">+</span> biasid<span class="token punctuation">,</span>
<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">_pad_1x1_to_3x3_tensor</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> kernel1x1<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> kernel1x1 <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> <span class="token number">0</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> nn<span class="token punctuation">.</span>functional<span class="token punctuation">.</span>pad<span class="token punctuation">(</span>kernel1x1<span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">_fuse_bn_tensor</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> branch<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> branch <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span>
<span class="token keyword">if</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>branch<span class="token punctuation">,</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">)</span><span class="token punctuation">:</span>
kernel <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">.</span>weight
running_mean <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>running_mean
running_var <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>running_var
gamma <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>weight
beta <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>bias
eps <span class="token operator">=</span> branch<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>eps
<span class="token keyword">else</span><span class="token punctuation">:</span>
<span class="token keyword">assert</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>branch<span class="token punctuation">,</span> nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">)</span>
<span class="token keyword">if</span> <span class="token keyword">not</span> <span class="token builtin">hasattr</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> <span class="token string">"id_tensor"</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
input_dim <span class="token operator">=</span> self<span class="token punctuation">.</span>in_channels <span class="token operator">//</span> self<span class="token punctuation">.</span>groups
kernel_value <span class="token operator">=</span> np<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>
<span class="token punctuation">(</span>self<span class="token punctuation">.</span>in_channels<span class="token punctuation">,</span> input_dim<span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>np<span class="token punctuation">.</span>float32
<span class="token punctuation">)</span>
<span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>in_channels<span class="token punctuation">)</span><span class="token punctuation">:</span>
kernel_value<span class="token punctuation">[</span>i<span class="token punctuation">,</span> i <span class="token operator">%</span> input_dim<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">1</span>
self<span class="token punctuation">.</span>id_tensor <span class="token operator">=</span> torch<span class="token punctuation">.</span>from_numpy<span class="token punctuation">(</span>kernel_value<span class="token punctuation">)</span><span class="token punctuation">.</span>to<span class="token punctuation">(</span>branch<span class="token punctuation">.</span>weight<span class="token punctuation">.</span>device<span class="token punctuation">)</span>
kernel <span class="token operator">=</span> self<span class="token punctuation">.</span>id_tensor
running_mean <span class="token operator">=</span> branch<span class="token punctuation">.</span>running_mean
running_var <span class="token operator">=</span> branch<span class="token punctuation">.</span>running_var
gamma <span class="token operator">=</span> branch<span class="token punctuation">.</span>weight
beta <span class="token operator">=</span> branch<span class="token punctuation">.</span>bias
eps <span class="token operator">=</span> branch<span class="token punctuation">.</span>eps
std <span class="token operator">=</span> <span class="token punctuation">(</span>running_var <span class="token operator">+</span> eps<span class="token punctuation">)</span><span class="token punctuation">.</span>sqrt<span class="token punctuation">(</span><span class="token punctuation">)</span>
t <span class="token operator">=</span> <span class="token punctuation">(</span>gamma <span class="token operator">/</span> std<span class="token punctuation">)</span><span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
<span class="token keyword">return</span> kernel <span class="token operator">*</span> t<span class="token punctuation">,</span> beta <span class="token operator">-</span> running_mean <span class="token operator">*</span> gamma <span class="token operator">/</span> std
    def repvgg_convert(self):
        kernel, bias = self.get_equivalent_kernel_bias()
        return (
            kernel.detach().cpu().numpy(),
            bias.detach().cpu().numpy(),
        )
    def fuse_conv_bn(self, conv, bn):
        # Fold an eval-mode BatchNorm into the preceding conv (the conv has no bias of its own):
        #   W' = W * gamma / std,   b' = beta - running_mean * gamma / std
        std = (bn.running_var + bn.eps).sqrt()
        bias = bn.bias - bn.running_mean * bn.weight / std
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        weights = conv.weight * t
        bn = nn.Identity()
        conv = nn.Conv2d(in_channels = conv.in_channels,
                         out_channels = conv.out_channels,
                         kernel_size = conv.kernel_size,
                         stride = conv.stride,
                         padding = conv.padding,
                         dilation = conv.dilation,
                         groups = conv.groups,
                         bias = True,
                         padding_mode = conv.padding_mode)
        conv.weight = torch.nn.Parameter(weights)
        conv.bias = torch.nn.Parameter(bias)
        return conv
    def fuse_repvgg_block(self):
        # Merge the 3x3, 1x1 and identity branches into a single 3x3 conv (deploy mode)
        if self.deploy:
            return
        print(f"RepConv.fuse_repvgg_block")
        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
        rbr_1x1_bias = self.rbr_1x1.bias
        # Zero-pad the 1x1 kernel to 3x3 so that it can be added to the dense kernel
        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
        # Fuse self.rbr_identity
        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
            # Write the identity branch as a 1x1 conv whose weight is the identity matrix
            # (fill_diagonal_ gives a true identity only when groups == 1)
            identity_conv_1x1 = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=self.groups,
                bias=False)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
            identity_conv_1x1.weight.data.fill_(0.0)
            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
            bias_identity_expanded = identity_conv_1x1.bias
            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
        else:
            # No identity branch: it contributes zeros
            bias_identity_expanded = torch.nn.Parameter(torch.zeros_like(rbr_1x1_bias))
            weight_identity_expanded = torch.nn.Parameter(torch.zeros_like(weight_1x1_expanded))
        # Sum the kernels and biases of the three branches
        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
        self.rbr_reparam = self.rbr_dense
        self.deploy = True
        # Drop the training-time branches, which are no longer needed
        if self.rbr_identity is not None:
            del self.rbr_identity
            self.rbr_identity = None
        if self.rbr_1x1 is not None:
            del self.rbr_1x1
            self.rbr_1x1 = None
        if self.rbr_dense is not None:
            del self.rbr_dense
            self.rbr_dense = None
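```python
# ---------------------------------------------------------------------
# Standalone check, not part of the original repository: a minimal
# sketch of the RepConv re-parameterization that fuse_conv_bn and
# fuse_repvgg_block perform. It assumes groups == 1,
# in_channels == out_channels, stride 1, and BatchNorm in eval mode.
# The helper names fold_bn, make_bn and check_rep_fusion are
# hypothetical.
#
# The check works as follows:
#   1. Fold each BN into its conv.
#   2. Zero-pad the 1x1 kernel to 3x3.
#   3. Write the identity branch as a 3x3 kernel with a single 1 at the
#      centre tap.
#   4. Compare one conv with the summed kernel and bias against the
#      multi-branch output.
# ---------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def fold_bn(kernel, bn):
    # Same formula as fuse_conv_bn: W' = W * gamma / std, b' = beta - mean * gamma / std
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    return kernel * t, bn.bias - bn.running_mean * bn.weight / std


def make_bn(c):
    # Eval-mode BatchNorm with random, non-trivial statistics
    bn = nn.BatchNorm2d(c)
    with torch.no_grad():
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
    return bn.eval()


def check_rep_fusion(c=8, seed=0):
    torch.manual_seed(seed)
    conv3 = nn.Conv2d(c, c, 3, 1, 1, bias=False)  # dense 3x3 branch
    conv1 = nn.Conv2d(c, c, 1, 1, 0, bias=False)  # 1x1 branch
    bn3, bn1, bn_id = make_bn(c), make_bn(c), make_bn(c)
    x = torch.randn(2, c, 16, 16)
    with torch.no_grad():
        # Training-time structure: three parallel branches, summed
        y_multi = bn3(conv3(x)) + bn1(conv1(x)) + bn_id(x)
        k3, b3 = fold_bn(conv3.weight, bn3)
        k1, b1 = fold_bn(conv1.weight, bn1)
        k1 = F.pad(k1, [1, 1, 1, 1])  # 1x1 -> 3x3, value at the centre tap
        k_id = torch.zeros(c, c, 3, 3)
        idx = torch.arange(c)
        k_id[idx, idx, 1, 1] = 1.0    # identity mapping as a 3x3 kernel
        k_id, b_id = fold_bn(k_id, bn_id)
        # Deploy-time structure: one 3x3 conv with the summed kernel and bias
        y_fused = F.conv2d(x, k3 + k1 + k_id, b3 + b1 + b_id, stride=1, padding=1)
    # Largest absolute difference between the two outputs
    return (y_multi - y_fused).abs().max().item()


print(check_rep_fusion())  # a value close to 0 (float32 round-off only)
```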
#---------------------------------------------------#
# yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi, pretrained=False):
        super(YoloBody, self).__init__()
        #-----------------------------------------------#
        #   Parameters for the different YoloV7 versions
        #-----------------------------------------------#
        transition_channels = {'l' : 32, 'x' : 40}[phi]
        block_channels      = 32
        panet_channels      = {'l' : 32, 'x' : 64}[phi]
        e       = {'l' : 2, 'x' : 1}[phi]
        n       = {'l' : 4, 'x' : 6}[phi]
        ids     = {'l' : [-1, -2, -3, -4, -5, -6], 'x' : [-1, -3, -5, -7, -8]}[phi]
        conv    = {'l' : RepConv, 'x' : Conv}[phi]
        #-----------------------------------------------#
        #   The input image is 640, 640, 3
        #-----------------------------------------------#
        #---------------------------------------------------#
        #   Build the backbone.
        #   It returns three effective feature layers whose
        #   shapes (for phi = 'l') are:
        #   80, 80, 512
        #   40, 40, 1024
        #   20, 20, 1024
        #---------------------------------------------------#
        self.backbone   = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained)

        self.upsample   = nn.Upsample(scale_factor=2, mode="nearest")

        self.sppcspc                = SPPCSPC(transition_channels * 32, transition_channels * 16)
        # Top-down path (upsampling and concatenation)
        self.conv_for_P5            = Conv(transition_channels * 16, transition_channels * 8)
        self.conv_for_feat2         = Conv(transition_channels * 32, transition_channels * 8)
        self.conv3_for_upsample1    = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        self.conv_for_P4            = Conv(transition_channels * 8, transition_channels * 4)
        self.conv_for_feat1         = Conv(transition_channels * 16, transition_channels * 4)
        self.conv3_for_upsample2    = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)
        # Bottom-up path (downsampling and concatenation)
        self.down_sample1           = Transition_Block(transition_channels * 4, transition_channels * 4)
        self.conv3_for_downsample1  = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        self.down_sample2           = Transition_Block(transition_channels * 8, transition_channels * 8)
        self.conv3_for_downsample2  = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)
        # 3x3 RepConv ('l') or Conv ('x') before each head
        self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1)
        self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1)
        self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1)
        # 1x1 heads: each anchor predicts 4 box offsets + 1 objectness + num_classes scores
        self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
        self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
        self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
    def fuse(self):
        # Re-parameterize RepConv blocks and fold Conv+BN pairs for faster inference
        print('Fusing layers... ')
        for m in self.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()
            elif type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # fold the BN into the conv
                delattr(m, 'bn')                         # remove the BatchNorm
                m.forward = m.fuseforward                # switch to the BN-free forward
        return self
    def forward(self, x):
        #  backbone
        feat1, feat2, feat3 = self.backbone.forward(x)

        # Shapes in the comments below are for phi = 'l'
        # 20, 20, 1024 => 20, 20, 512
        P5          = self.sppcspc(feat3)
        # 20, 20, 512 => 20, 20, 256
        P5_conv     = self.conv_for_P5(P5)
        # 20, 20, 256 => 40, 40, 256
        P5_upsample = self.upsample(P5_conv)
        # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512
        P4          = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1)
        # 40, 40, 512 => 40, 40, 256
        P4          = self.conv3_for_upsample1(P4)

        # 40, 40, 256 => 40, 40, 128
        P4_conv     = self.conv_for_P4(P4)
        # 40, 40, 128 => 80, 80, 128
        P4_upsample = self.upsample(P4_conv)
        # 80, 80, 128 cat 80, 80, 128 => 80, 80, 256
        P3          = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1)
        # 80, 80, 256 => 80, 80, 128
        P3          = self.conv3_for_upsample2(P3)

        # 80, 80, 128 => 40, 40, 256
        P3_downsample = self.down_sample1(P3)
        # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512
        P4 = torch.cat([P3_downsample, P4], 1)
        # 40, 40, 512 => 40, 40, 256
        P4 = self.conv3_for_downsample1(P4)
P4_downsample <span class="token operator">=</span> self<span class="token punctuation">.</span>down_sample2<span class="token punctuation">(</span>P4<span class="token punctuation">)</span>
P5 <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>P4_downsample<span class="token punctuation">,</span> P5<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
P5 <span class="token operator">=</span> self<span class="token punctuation">.</span>conv3_for_downsample2<span class="token punctuation">(</span>P5<span class="token punctuation">)</span>
P3 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_1<span class="token punctuation">(</span>P3<span class="token punctuation">)</span>
P4 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_2<span class="token punctuation">(</span>P4<span class="token punctuation">)</span>
P5 <span class="token operator">=</span> self<span class="token punctuation">.</span>rep_conv_3<span class="token punctuation">(</span>P5<span class="token punctuation">)</span>
<span class="token comment">#---------------------------------------------------#</span>
<span class="token comment"># 第三个特征层</span>
<span class="token comment"># y3=(batch_size, 75, 80, 80)</span>
<span class="token comment">#---------------------------------------------------#</span>
out2 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P3<span class="token punctuation">(</span>P3<span class="token punctuation">)</span>
<span class="token comment">#---------------------------------------------------#</span>
<span class="token comment"># 第二个特征层</span>
<span class="token comment"># y2=(batch_size, 75, 40, 40)</span>
<span class="token comment">#---------------------------------------------------#</span>
out1 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P4<span class="token punctuation">(</span>P4<span class="token punctuation">)</span>
<span class="token comment">#---------------------------------------------------#</span>
<span class="token comment"># 第一个特征层</span>
<span class="token comment"># y1=(batch_size, 75, 20, 20)</span>
<span class="token comment">#---------------------------------------------------#</span>
out0 <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_head_P5<span class="token punctuation">(</span>P5<span class="token punctuation">)</span>
<span class="token keyword">return</span> <span class="token punctuation">[</span>out0<span class="token punctuation">,</span> out1<span class="token punctuation">,</span> out2<span class="token punctuation">]</span>
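The shape comments in `forward` follow from two small formulas: a head outputs `num_anchors * (5 + num_classes)` channels, and a layer's grid size is the input size divided by its stride. A minimal sketch, assuming 3 anchors per feature point and strides of 8, 16 and 32; the helper names are hypothetical:

```python
def head_channels(num_classes: int, num_anchors: int = 3) -> int:
    # Each anchor predicts 4 box offsets + 1 objectness score + num_classes class scores.
    return num_anchors * (5 + num_classes)


def grid_sizes(input_size: int = 640, strides=(8, 16, 32)) -> list:
    # Each feature layer downsamples the (square) input by its stride.
    return [input_size // s for s in strides]


print(head_channels(20))  # 75  (20 classes)
print(head_channels(80))  # 255 (80 COCO classes)
print(grid_sizes())       # [80, 40, 20] -> P3, P4, P5
```

The same arithmetic explains why the decoding section below works with 255 channels: it assumes the 80 COCO classes instead of 20.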
III. Decoding the Predictions
1. Obtaining the prediction boxes and scores
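The network described in Part II outputs predictions for three feature layers, with shapes (N,20,20,255), (N,40,40,255) and (N,80,80,255) (written channels-last here; the PyTorch tensors themselves are laid out as (N,255,20,20) and so on).
These predictions do not yet give the positions of the final boxes on the image; they must be decoded first. In YoloV7, as in YoloV5, every feature point on every feature layer has 3 anchors (prior boxes).
The last dimension of 255 therefore splits into 3 × 85, i.e. 85 parameters for each of the 3 anchors. After reshaping, the results are (N,20,20,3,85), (N,40,40,3,85) and (N,80,80,3,85).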
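The 85 parameters split as 4 + 1 + 80:
The first 4 are the regression parameters of each anchor; applying them to the anchor yields the prediction box;
The 5th indicates whether the anchor contains an object;
The last 80 give the class of that object (one score per class).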
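The 255 → 3 × 85 reshape and the 4 + 1 + 80 split can be checked with a small sketch. NumPy stands in for PyTorch here (`reshape` + `transpose` behave like `view` + `permute` on a contiguous tensor); the batch size and the `arange` values are made up purely so the index mapping can be verified:

```python
import numpy as np

# Illustrative sizes: batch of 2, 3 anchors, 80 COCO classes, one 20x20 feature layer.
N, A, C, H, W = 2, 3, 80, 20, 20
attrs = 5 + C  # 4 box offsets + 1 objectness + C class scores = 85

# Stand-in for a raw head output in PyTorch's channels-first layout: (N, 255, 20, 20).
# arange makes every value unique, so the index mapping can be checked.
raw = np.arange(N * A * attrs * H * W, dtype=np.float32).reshape(N, A * attrs, H, W)

# Equivalent of input.view(N, 3, 85, H, W).permute(0, 1, 3, 4, 2) in decode_box.
pred = raw.reshape(N, A, attrs, H, W).transpose(0, 1, 3, 4, 2)
print(pred.shape)  # (2, 3, 20, 20, 85)

# Split the last axis as 4 + 1 + 80.
box_offsets = pred[..., :4]   # regression parameters tx, ty, tw, th
objectness = pred[..., 4]     # does this anchor contain an object?
class_scores = pred[..., 5:]  # one score per class
print(box_offsets.shape, objectness.shape, class_scores.shape)
# (2, 3, 20, 20, 4) (2, 3, 20, 20) (2, 3, 20, 20, 80)
```

Attribute `k` of anchor `a` at cell `(y, x)` comes from channel `a * 85 + k` of the raw output, which is why the anchor dimension must be split off before the attribute dimension.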
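Take the (N,20,20,3,85) layer as an example: it divides the image into 20x20 feature points, and a feature point that falls inside an object's box is used to predict that object.
In the figure, the blue points are the 20x20 feature points; below we walk through decoding the three anchors of the black point in the left image: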
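1. Compute the predicted centres: the first two regression values offset the centres of the feature point's three anchors, giving the three red points in the right image. In YoloV7 each offset is sigmoid(t) × 2 − 0.5, added to the grid coordinate, so a centre can move between −0.5 and 1.5 cells from the feature point's top-left corner;
2. Compute the predicted width and height: the last two regression values are passed through a sigmoid, multiplied by 2 and squared, and the result scales the anchor's width and height, so a box can range from 0 to 4 times its anchor (instead of the unbounded exponential used in earlier YOLO versions);
3. The resulting prediction boxes can now be drawn on the image.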
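For a single anchor, the two decoding steps reduce to a few lines of arithmetic. This sketch follows the formulas in `decode_box` but works directly in input pixels (multiplying the centre by the stride and using the anchor in pixels, rather than normalising); the function name and the anchor size in the example are illustrative assumptions:

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def decode_one_anchor(t, grid_xy, anchor_wh, stride):
    """Decode raw offsets t = (tx, ty, tw, th) for one anchor into a pixel box.

    grid_xy:   (gx, gy), integer position of the feature point on the grid
    anchor_wh: anchor width/height in pixels on the network input
    stride:    input pixels per grid cell (8, 16 or 32)
    Returns (cx, cy, w, h) in input pixels.
    """
    tx, ty, tw, th = t
    gx, gy = grid_xy
    # Centre: shifted by (-0.5, 1.5) cells from the feature point's top-left corner.
    cx = (sigmoid(tx) * 2.0 - 0.5 + gx) * stride
    cy = (sigmoid(ty) * 2.0 - 0.5 + gy) * stride
    # Size: the anchor scaled by a factor in (0, 4) instead of exp(t).
    w = (sigmoid(tw) * 2.0) ** 2 * anchor_wh[0]
    h = (sigmoid(th) * 2.0) ** 2 * anchor_wh[1]
    return cx, cy, w, h


# With all-zero offsets, sigmoid gives 0.5: the box sits at the cell centre with the anchor's size.
print(decode_one_anchor((0, 0, 0, 0), grid_xy=(3, 5), anchor_wh=(142, 110), stride=32))
# (112.0, 176.0, 142.0, 110.0)
```

Because both the centre shift and the size factor pass through a sigmoid, even extreme raw outputs produce bounded boxes, which tends to make early training more stable than the exponential form.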
Besides decoding, non-maximum suppression (NMS) must also be applied, to prevent piles of overlapping boxes of the same class.
def decode_box(self, inputs):
    outputs = []
    for i, input in enumerate(inputs):
        #-----------------------------------------------#
        #   The three inputs have shapes
        #   batch_size, 255, 20, 20
        #   batch_size, 255, 40, 40
        #   batch_size, 255, 80, 80
        #-----------------------------------------------#
        batch_size      = input.size(0)
        input_height    = input.size(2)
        input_width     = input.size(3)

        #-----------------------------------------------#
        #   stride_h = stride_w = 32, 16, 8
        #   (e.g. for a 640x640 input, 640 / 20 = 32)
        #-----------------------------------------------#
        stride_h = self.input_shape[0] / input_height
        stride_w = self.input_shape[1] / input_width
        #-------------------------------------------------#
        #   scaled_anchors are expressed in feature-layer (grid) units
        #-------------------------------------------------#
        scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h)
                          for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]

        #-----------------------------------------------#
        #   After reshaping, the three inputs have shapes
        #   batch_size, 3, 20, 20, 85
        #   batch_size, 3, 40, 40, 85
        #   batch_size, 3, 80, 80, 85
        #-----------------------------------------------#
        prediction = input.view(batch_size, len(self.anchors_mask[i]),
                                self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()

        #   offsets for the anchor centre
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        #   adjustments for the anchor width and height
        w = torch.sigmoid(prediction[..., 2])
        h = torch.sigmoid(prediction[..., 3])
        #   objectness: does the anchor contain an object?
        conf     = torch.sigmoid(prediction[..., 4])
        #   class confidences
        pred_cls = torch.sigmoid(prediction[..., 5:])

        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor  = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

        #----------------------------------------------------------#
        #   Build the grid: anchor centres start at each cell's top-left corner
        #   batch_size, 3, 20, 20
        #----------------------------------------------------------#
        grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
            batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
            batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)

        #----------------------------------------------------------#
        #   Lay out the anchor widths and heights on the same grid
        #   batch_size, 3, 20, 20
        #----------------------------------------------------------#
        anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
        anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

        #----------------------------------------------------------#
        #   Adjust the anchors with the predictions:
        #   first offset the centre from the cell's top-left corner,
        #   then rescale the width and height.
        #----------------------------------------------------------#
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = x.data * 2. - 0.5 + grid_x
        pred_boxes[..., 1] = y.data * 2. - 0.5 + grid_y
        pred_boxes[..., 2] = (w.data * 2) ** 2 * anchor_w
        pred_boxes[..., 3] = (h.data * 2) ** 2 * anchor_h

        #----------------------------------------------------------#
        #   Normalise the boxes to fractions (0-1) of the feature map
        #----------------------------------------------------------#
        _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
        output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
                            conf.view(batch_size, -1, 1),
                            pred_cls.view(batch_size, -1, self.num_classes)), -1)
        outputs.append(output.data)
    return outputs
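The grid and anchor bookkeeping in `decode_box` (the `linspace().repeat()` chains and `index_select`) can also be expressed with broadcasting. The NumPy sketch below reproduces the same formulas for a single feature layer; it is an illustration rather than the repository's code, it assumes a square input (one stride for both axes), and `decode_layer` is a hypothetical name:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def decode_layer(raw, anchors_px, input_size, num_classes):
    """Sketch of decode_box for ONE feature layer (assumes a square input).

    raw:        head output, shape (N, A*(5+num_classes), H, W)
    anchors_px: A anchors as (w, h) in input pixels
    Returns (N, A*H*W, 5+num_classes): boxes as (cx, cy, w, h) normalised to 0-1,
    followed by objectness and class scores.
    """
    n, _, h, w = raw.shape
    a = len(anchors_px)
    stride = input_size / h
    # (N, A, 5+C, H, W) -> (N, A, H, W, 5+C), like view(...).permute(0, 1, 3, 4, 2)
    p = sigmoid(raw.reshape(n, a, 5 + num_classes, h, w).transpose(0, 1, 3, 4, 2))

    # Cell top-left corners; meshgrid replaces the linspace().repeat() chains.
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))  # each (H, W)
    # Anchors in grid units, shaped (1, A, 1, 1) so they broadcast over N, H, W.
    anchors = np.asarray(anchors_px, dtype=np.float64) / stride
    anchor_w = anchors[:, 0].reshape(1, a, 1, 1)
    anchor_h = anchors[:, 1].reshape(1, a, 1, 1)

    cx = p[..., 0] * 2.0 - 0.5 + grid_x
    cy = p[..., 1] * 2.0 - 0.5 + grid_y
    bw = (p[..., 2] * 2.0) ** 2 * anchor_w
    bh = (p[..., 3] * 2.0) ** 2 * anchor_h

    # Divide by the grid size so every layer shares the same 0-1 coordinates.
    boxes = np.stack([cx / w, cy / h, bw / w, bh / h], axis=-1)
    out = np.concatenate([boxes, p[..., 4:]], axis=-1)
    return out.reshape(n, a * h * w, 5 + num_classes)
```

For example, `decode_layer(np.zeros((1, 21, 4, 4)), [(8, 8), (16, 16), (32, 32)], input_size=32, num_classes=2)` returns an array of shape `(1, 48, 7)`. Because every layer ends up in the same 0-1 coordinates, the outputs of the three layers can be concatenated along axis 1 before score filtering and NMS.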
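2. Score filtering and non-maximum suppression
Once the final predictions are available, they still need to be filtered by score and passed through non-maximum suppression.
Score filtering keeps only the prediction boxes whose score reaches the confidence threshold.
Non-maximum suppression keeps, within a given region, only the highest-scoring box of each class.
The process can be summarized as follows: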
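1. Find the boxes in the image whose score exceeds the threshold. Filtering by score before removing overlapping boxes greatly reduces the number of boxes to process.
2. Loop over the classes. Since NMS keeps the highest-scoring box of each class within a region, looping over the classes lets us apply it to each class separately.
3. Sort the boxes of the current class by score, from highest to lowest.
4. Repeatedly take the highest-scoring box, compute its overlap (IoU) with all remaining prediction boxes, and discard those whose overlap is too large.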
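The four steps can be sketched in plain Python. This is a simplified, hypothetical illustration: each box is a tuple `(x1, y1, x2, y2, score, class_id)` with a single score, whereas the repository code multiplies objectness by class confidence for the score and hands the per-class suppression to a built-in `nms`:

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def score_filter_and_nms(detections, conf_thres=0.5, nms_thres=0.4):
    """detections: list of (x1, y1, x2, y2, score, class_id). Returns the kept boxes."""
    # Step 1: drop low-scoring boxes before the more expensive overlap checks.
    candidates = [d for d in detections if d[4] >= conf_thres]
    kept = []
    # Step 2: handle each class separately.
    for cls in sorted({d[5] for d in candidates}):
        # Step 3: sort this class's boxes by score, highest first.
        remaining = sorted((d for d in candidates if d[5] == cls),
                           key=lambda d: d[4], reverse=True)
        # Step 4: keep the best box, discard boxes that overlap it too much, repeat.
        while remaining:
            best = remaining.pop(0)
            kept.append(best)
            remaining = [d for d in remaining if iou(best[:4], d[:4]) <= nms_thres]
    return kept


dets = [
    (0, 0, 10, 10, 0.9, 0),    # best class-0 box
    (1, 1, 11, 11, 0.8, 0),    # IoU 81/119 with the first box -> suppressed
    (20, 20, 30, 30, 0.7, 0),  # no overlap -> kept
    (1, 1, 11, 11, 0.85, 1),   # same place but another class -> kept
    (50, 50, 60, 60, 0.3, 0),  # below conf_thres -> filtered out in step 1
]
print(score_filter_and_nms(dets))
```

Looping over classes means a box of one class never suppresses a box of another class, even when they overlap completely.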
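The boxes that survive score filtering and non-maximum suppression can then be drawn.
The figure below shows the result with non-maximum suppression.
The figure below shows the result without non-maximum suppression.
The implementation is: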
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    #----------------------------------------------------------#
    #   Convert the predictions from (cx, cy, w, h) to
    #   top-left / bottom-right corner format (x1, y1, x2, y2).
    #   prediction  [batch_size, num_anchors, 85]
    #----------------------------------------------------------#
    box_corner          = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        #----------------------------------------------------------#
        #   Take the max over the class predictions.
        #   class_conf  [num_anchors, 1]    class confidence
        #   class_pred  [num_anchors, 1]    class index
        #----------------------------------------------------------#
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        #----------------------------------------------------------#
        #   First-round filtering by confidence
        #----------------------------------------------------------#
        conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

        #----------------------------------------------------------#
        #   Keep only the predictions that pass the confidence filter
        #----------------------------------------------------------#
        image_pred = image_pred[conf_mask]
        class_conf = class_conf[conf_mask]
        class_pred = class_pred[conf_mask]
        if not image_pred.size(0):
            continue
        #-------------------------------------------------------------------------#
        #   detections  [num_anchors, 7]
        #   the 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
        #-------------------------------------------------------------------------#
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

        #------------------------------------------#
        #   All classes present in the predictions
        #------------------------------------------#
        unique_labels = detections[:, -1].cpu().unique()

        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
            detections = detections.cuda()

        for c in unique_labels:
            #------------------------------------------#
            #   All score-filtered predictions of this class
            #------------------------------------------#
            detections_class = detections[detections[:, -1] == c]

            #------------------------------------------#
            #   The official built-in nms is somewhat faster!
            #------------------------------------------#
            keep = nms(
                detections_class[:, :4
class="token punctuation">]</span><span class="token punctuation">,</span> detections_class<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">]</span> <span class="token operator">*</span> detections_class<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">]</span><span class="token punctuation">,</span> nms_thres <span class="token punctuation">)</span> max_detections <span class="token operator">=</span> detections_class<span class="token punctuation">[</span>keep<span class="token punctuation">]</span> <span class="token comment"># # 按照存在物体的置信度排序</span> <span class="token comment"># _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)</span> <span class="token comment"># detections_class = detections_class[conf_sort_index]</span> <span class="token comment"># # 进行非极大抑制</span> <span class="token comment"># max_detections = []</span> <span class="token comment"># while detections_class.size(0):</span> <span class="token comment"># # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉</span> <span class="token comment"># max_detections.append(detections_class[0].unsqueeze(0))</span> <span class="token comment"># if len(detections_class) == 1:</span> <span class="token comment"># break</span> <span class="token comment"># ious = bbox_iou(max_detections[-1], detections_class[1:])</span> <span class="token comment"># detections_class = detections_class[1:][ious < nms_thres]</span> <span class="token comment"># # 堆叠</span> <span class="token comment"># max_detections = torch.cat(max_detections).data</span> <span class="token comment"># Add max detections to outputs</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span> 
<span class="token operator">=</span> max_detections <span class="token keyword">if</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span> <span class="token keyword">is</span> <span class="token boolean">None</span> <span class="token keyword">else</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">(</span>output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">,</span> max_detections<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">if</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span> <span class="token keyword">is</span> <span class="token keyword">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span> <span class="token operator">=</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span> box_xy<span class="token punctuation">,</span> box_wh <span class="token operator">=</span> <span class="token punctuation">(</span>output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span> <span class="token operator">+</span> output<span class="token punctuation">[</span>i<span class="token 
punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token operator">/</span><span class="token number">2</span><span class="token punctuation">,</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span> <span class="token operator">-</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span> output<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span> <span class="token operator">=</span> self<span class="token punctuation">.</span>yolo_correct_boxes<span class="token punctuation">(</span>box_xy<span class="token punctuation">,</span> box_wh<span class="token punctuation">,</span> input_shape<span class="token punctuation">,</span> image_shape<span class="token punctuation">,</span> letterbox_image<span class="token punctuation">)</span> <span class="token keyword">return</span> output
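The commented-out loop in the code above spells out the greedy, per-class NMS that the faster built-in `nms` call replaces. Below is a minimal pure-Python sketch of that loop, for readers who want to follow the logic without tensors. It assumes each detection is a plain tuple `(x1, y1, x2, y2, obj_conf, class_conf, class_pred)`, matching the `detections` layout above; the helper names `box_iou` and `per_class_nms` are hypothetical.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2, ...); extra fields are ignored."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def per_class_nms(detections, nms_thres=0.4):
    """detections: list of (x1, y1, x2, y2, obj_conf, class_conf, class_pred) tuples."""
    kept = []
    for c in sorted({d[6] for d in detections}):
        # Boxes of class c, sorted by obj_conf * class_conf (highest first)
        remaining = sorted((d for d in detections if d[6] == c),
                           key=lambda d: d[4] * d[5], reverse=True)
        while remaining:
            best = remaining.pop(0)
            kept.append(best)
            # Drop same-class boxes whose overlap with the kept box reaches nms_thres
            remaining = [d for d in remaining if box_iou(best, d) < nms_thres]
    return kept


dets = [
    (0, 0, 10, 10, 0.9, 0.9, 0),
    (1, 1, 11, 11, 0.8, 0.9, 0),    # IoU with the first box is 81/119, about 0.68
    (1, 1, 11, 11, 0.8, 0.9, 1),    # same place, different class: kept
    (50, 50, 60, 60, 0.7, 0.9, 0),
]
kept = per_class_nms(dets, nms_thres=0.4)
for d in kept:
    print(d)
```

Because suppression only happens within a class, the second and third boxes share coordinates but only the class-0 copy is removed.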
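IV. Training
1. What the loss needs
Computing the loss amounts to comparing the network's predictions with the ground truth.
Like the network's predictions, the loss consists of three parts: Reg, Obj and Cls. The Reg part concerns each feature point's box-regression parameters, the Obj part whether the feature point contains an object, and the Cls part the class of the object at the feature point.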
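To make the three parts concrete, here is a minimal pure-Python sketch of the loss contributed by a single feature point. It is an illustration under stated assumptions, not this repository's loss: the Reg part is simplified to 1 - IoU (YoloV7 itself uses a CIoU-style box loss), the Obj target of a positive point is fixed at 1, and the weights `w_reg`, `w_obj`, `w_cls` as well as all function names are hypothetical.

```python
import math


def bce(p, y, eps=1e-7):
    """Binary cross-entropy between a probability p and a target y in [0, 1]."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def positive_sample_loss(pred_box, gt_box, obj_prob, cls_probs, gt_class,
                         w_reg=0.05, w_obj=1.0, w_cls=0.5):
    """Loss of one positive feature point; the weights are hypothetical."""
    # Reg part: how well the regressed box fits the ground truth (1 - IoU in this sketch)
    loss_reg = 1.0 - box_iou(pred_box, gt_box)
    # Obj part: a positive point should report "object here" (target 1)
    loss_obj = bce(obj_prob, 1.0)
    # Cls part: one binary cross-entropy per class against a one-hot target
    loss_cls = sum(bce(p, 1.0 if k == gt_class else 0.0) for k, p in enumerate(cls_probs))
    return w_reg * loss_reg + w_obj * loss_obj + w_cls * loss_cls


def negative_sample_loss(obj_prob, w_obj=1.0):
    """A negative feature point only contributes to the Obj part (target 0)."""
    return w_obj * bce(obj_prob, 0.0)


good = positive_sample_loss((0, 0, 10, 10), (0, 0, 10, 10), 0.99, [0.99, 0.01, 0.01], 0)
bad = positive_sample_loss((5, 5, 15, 15), (0, 0, 10, 10), 0.30, [0.20, 0.70, 0.10], 0)
print(round(good, 4), round(bad, 4))
```

A well-placed, confident, correctly classified prediction yields a small loss; a shifted box with a wrong class guess is penalized in all three parts.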
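2. Positive-sample matching
In YoloV7, positive samples are matched during training in two stages:
a. For each ground-truth box, coarsely match anchors and feature points using its coordinates and width/height.
b. Use SimOTA to adaptively and precisely decide how many anchors each ground-truth box gets.
Positive-sample matching means finding which anchors are considered to have a corresponding ground-truth box and are therefore responsible for predicting it.
a. Matching anchors and feature points
In this stage, YoloV7 performs a coarse match for every ground-truth box, finding which anchors at which feature points can be responsible for predicting it.
Anchors are matched first. The YoloV7 network uses 9 anchors of different sizes, 3 for each output feature layer.
For any ground-truth box gt, YoloV7 no longer uses IoU to match positive samples. Instead it matches directly on width/height ratios, comparing the ground-truth box with each of the 9 anchors.
If the width/height ratio between the ground-truth box and an anchor exceeds the threshold, the two do not match well enough, and that anchor is treated as a negative sample for this box.
For example, suppose a ground-truth box has width and height [200, 200], i.e. it is a square. The 9 default YoloV7 anchors are [12, 16], [19, 36], [40, 28], [36, 75], [76, 55], [72, 146], [142, 110], [192, 243], [459, 401], and the threshold is 4.
We now compute the width/height ratios between this ground-truth box and the 9 anchors. There are two cases: the ground-truth box can be larger than the anchor, or the anchor can be larger than the ground-truth box. We therefore compute both gt width/height ÷ anchor width/height and anchor width/height ÷ gt width/height, and take the largest value.
The matrix below holds the results. Its shape is [9, 4]: the 9 rows are the 9 anchors, and the 4 columns are gt width / anchor width, gt height / anchor height, anchor width / gt width and anchor height / gt height.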
[[16.66666667 12.5 0.06 0.08 ]
[10.52631579 5.55555556 0.095 0.18 ]
[ 5. 7.14285714 0.2 0.14 ]
[ 5.55555556 2.66666667 0.18 0.375 ]
[ 2.63157895 3.63636364 0.38 0.275 ]
[ 2.77777778 1.36986301 0.36 0.73 ]
[ 1.4084507 1.81818182 0.71 0.55 ]
[ 1.04166667 0.82304527 0.96 1.215 ]
[ 0.43572985 0.49875312 2.295 2.005 ]]
Taking the maximum of each anchor's row gives:
[16.66666667 10.52631579 7.14285714 5.55555556 3.63636364 2.77777778
1.81818182 1.215 2.295 ]
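We then check which anchors have a result below the threshold: the five anchors [76, 55], [72, 146], [142, 110], [192, 243] and [459, 401] qualify.
[142, 110], [192, 243] and [459, 401] belong to the 20x20 feature layer.
[76, 55] and [72, 146] belong to the 40x40 feature layer.
At this point we know which anchor sizes can be used to predict this ground-truth box.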
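The worked example above can be reproduced in a few lines of Python. This sketch compares sizes in pixels; since the test only looks at ratios, it gives the same answer as comparing in feature-map units. The anchor-to-layer split (including 80x80 for the three smallest anchors at a 640x640 input) is an assumption consistent with the two layers named above, and the function names are hypothetical.

```python
anchors = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
           [72, 146], [142, 110], [192, 243], [459, 401]]
# Assumed anchor-to-layer split for a 640x640 input (consistent with the text above)
anchors_mask = {"20x20": [6, 7, 8], "40x40": [3, 4, 5], "80x80": [0, 1, 2]}


def match_by_ratio(gt_wh, anchors, threshold=4.0):
    """Indices of anchors whose worst width/height ratio to the ground truth is below threshold."""
    matched = []
    for idx, (aw, ah) in enumerate(anchors):
        ratios = [gt_wh[0] / aw, gt_wh[1] / ah,   # ground truth / anchor
                  aw / gt_wh[0], ah / gt_wh[1]]   # anchor / ground truth
        if max(ratios) < threshold:
            matched.append(idx)
    return matched


def layer_of(anchor_idx):
    """Name of the feature layer that owns the given anchor index."""
    return next(name for name, mask in anchors_mask.items() if anchor_idx in mask)


matched = match_by_ratio([200, 200], anchors)
for idx in matched:
    print(anchors[idx], layer_of(idx))
# [76, 55] 40x40
# [72, 146] 40x40
# [142, 110] 20x20
# [192, 243] 20x20
# [459, 401] 20x20
```

Taking `max` over both directions is what makes the test symmetric: an anchor four times too small and one four times too large are rejected alike.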
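In YOLO versions before YoloV5, each ground-truth box was predicted by the top-left feature point of the grid cell containing its center.
YoloV7 follows YoloV5 here: on each selected feature layer, it first finds the grid cell that the ground-truth box falls into, and the top-left feature point of that cell is one responsible feature point.
It then uses a rounding rule to find the two nearest neighboring cells, and all three cells are treated as responsible for predicting this ground-truth box.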
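In the figure, the red dot marks the center of the ground-truth box; besides the cell it lies in, its two nearest neighboring cells are also selected. This shows that the XY offset of the predicted box can no longer be limited to 0 to 1: because the center may lie up to half a cell outside a neighboring cell, the offset must cover the range -0.5 to 1.5.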
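Once these feature points are found, the anchors at them that pass the width/height ratio test are responsible for predicting the ground-truth box.
This step is only a coarse filter; SimOTA then refines the selection.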
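The neighbor-cell rule can be sketched for a single ground-truth center in plain Python. The four conditions mirror the masks `j, k, l, m` in `find_3_positive` (center in the left/top half of its cell, or in the right/bottom half when measured from the far edge, and more than one cell away from that border). `decode_offset` shows why the predicted offset has to cover -0.5 to 1.5; it assumes YoloV5-style decoding, `2 * sigmoid(t) - 0.5`, applies here. Both function names are hypothetical.

```python
import math


def responsible_cells(cx, cy, grid_w, grid_h, g=0.5):
    """Grid cells (gi, gj) responsible for a center (cx, cy) given in grid units."""
    ix, iy = grid_w - cx, grid_h - cy      # center measured from the right/bottom edges
    offsets = [(0.0, 0.0)]                  # the cell that contains the center
    if cx % 1.0 < g and cx > 1.0:           # left half of its cell   -> left neighbor
        offsets.append((g, 0.0))
    if cy % 1.0 < g and cy > 1.0:           # top half of its cell    -> upper neighbor
        offsets.append((0.0, g))
    if ix % 1.0 < g and ix > 1.0:           # right half of its cell  -> right neighbor
        offsets.append((-g, 0.0))
    if iy % 1.0 < g and iy > 1.0:           # bottom half of its cell -> lower neighbor
        offsets.append((0.0, -g))
    cells = []
    for ox, oy in offsets:
        # Subtracting the offset moves the center into the neighboring cell
        gi = min(max(int(cx - ox), 0), grid_w - 1)
        gj = min(max(int(cy - oy), 0), grid_h - 1)
        cells.append((gi, gj))
    return cells


def decode_offset(t):
    """YoloV5-style xy decoding: 2 * sigmoid(t) - 0.5, which lies in (-0.5, 1.5)."""
    return 2.0 / (1.0 + math.exp(-t)) - 0.5


print(responsible_cells(5.3, 7.8, 20, 20))  # [(5, 7), (4, 7), (5, 8)]
```

For a center at (5.3, 7.8) on a 20x20 map, the point lies in the left half and the bottom half of cell (5, 7), so the left neighbor (4, 7) and the lower neighbor (5, 8) are added. Measured from those neighbors' top-left corners the center sits at x = 1.3 and y = -0.2, which is why the decoded offset must reach beyond 0 to 1.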
    def find_3_positive(self, predictions, targets):
        #------------------------------------#
        #   Number of anchors per feature layer
        #   and number of ground-truth boxes
        #------------------------------------#
        num_anchor, num_gt = len(self.anchors_mask[0]), targets.shape[0]
        #------------------------------------#
        #   Lists for the matched indices and anchors
        #------------------------------------#
        indices, anchors = [], []
        #------------------------------------#
        #   Seven gains, all initialised to 1:
        #   indices 0 and 1 stay 1,
        #   indices 2:6 become the feature-layer width/height,
        #   index 6 stays 1
        #------------------------------------#
        gain = torch.ones(7, device=targets.device)
        #------------------------------------#
        #   ai       [num_anchor, num_gt]
        #   targets  [num_gt, 6] => [num_anchor, num_gt, 7]
        #------------------------------------#
        ai = torch.arange(num_anchor, device=targets.device).float().view(num_anchor, 1).repeat(1, num_gt)
        targets = torch.cat((targets.repeat(num_anchor, 1, 1), ai[:, :, None]), 2)  # append anchor indices

        g = 0.5  # offsets
        off = torch.tensor([
            [0, 0],
            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
        ], device=targets.device).float() * g

        for i in range(len(predictions)):
            #----------------------------------------------------#
            #   Divide the anchors by the stride to express them
            #   relative to the feature layer.
            #   anchors_i  [num_anchor, 2]
            #----------------------------------------------------#
            anchors_i = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i])
            #-------------------------------------------#
            #   Width and height of this feature layer
            #-------------------------------------------#
            gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]
            #-------------------------------------------#
            #   Multiply the ground-truth boxes by gain,
            #   i.e. map them onto the feature layer
            #-------------------------------------------#
            t = targets * gain
            if num_gt:
                #-------------------------------------------#
                #   Ratio between ground-truth and anchor width/height;
                #   the pairs whose worst ratio is below the threshold
                #   are kept, giving the ground truth of every matched anchor
                #   r  [num_anchor, num_gt, 2]
                #   t  [num_anchor, num_gt, 7] => [num_matched_anchor, 7]
                #-------------------------------------------#
                r = t[:, :, 4:6] / anchors_i[:, None]
                j = torch.max(r, 1. / r).max(2)[0] < self.threshold
                t = t[j]  # filter

                #-------------------------------------------#
                #   gxy: x/y center of each matched ground truth
                #   gxi: the same center measured from the
                #        bottom-right corner of the feature layer
                #-------------------------------------------#
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))

                #-------------------------------------------#
                #   Repeat t five times and keep the rows selected by j.
                #   The five rows of j say, for the offsets
                #   [0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]
                #   (subtracted from the center), whether that cell is used
                #-------------------------------------------#
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            #-------------------------------------------#
            #   b    which image in the batch the box belongs to
            #   gxy  x/y center of the ground-truth box
            #   gwh  width/height of the ground-truth box
            #   gij  feature point (grid cell) responsible for it
            #-------------------------------------------#
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            #-------------------------------------------#
            #   gj and gi must stay inside the feature layer;
            #   a is the anchor index at that feature point.
            #   int() keeps the clamp bounds integral, since a float
            #   bound can raise a dtype error on newer PyTorch versions
            #-------------------------------------------#
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj.clamp_(0, int(gain[3]) - 1), gi.clamp_(0, int(gain[2]) - 1)))  # image, anchor, grid indices
            anchors.append(anchors_i[a])  # anchors
        return indices, anchors
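b. SimOTA adaptive matching
In YoloV7, we compute a Cost matrix that describes the cost between every ground-truth box and every candidate feature point. The Cost matrix has two components:
1. the overlap between each ground-truth box and the predicted box at the feature point;
2. how accurately the predicted box at the feature point predicts the class of each ground-truth box.
The higher the overlap between a ground-truth box and a feature point's predicted box, the more that feature point is already trying to fit this ground-truth box, so its Cost is lower.
Likewise, the more accurate the class prediction of a feature point's predicted box for a ground-truth box, the more that point is already trying to fit it, so its Cost is lower.
The Cost matrix lets each feature point adaptively find the ground-truth box it should fit: the higher the overlap, the more accurate the classification, and the closer the point lies to the box, the more it should fit that box. Here the candidates are already limited to the feature points kept by the coarse matching in step a.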
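A minimal sketch of the Cost matrix for one image follows. The exact form used here, a binary cross-entropy between sqrt(class score × objectness) and the one-hot class plus 3 × (-log IoU), follows my reading of the official YoloV7 OTA loss and should be treated as an assumption, including the weight 3.0. Inputs are plain lists of probabilities and IoUs, and `pairwise_cost` is a hypothetical name.

```python
import math


def pairwise_cost(ious, cls_probs, obj_probs, gt_classes, iou_weight=3.0, eps=1e-8):
    """cost[g][p] for ground truth g and candidate prediction p.

    ious[g][p]    IoU between ground truth g and predicted box p
    cls_probs[p]  per-class sigmoid scores of prediction p
    obj_probs[p]  sigmoid objectness of prediction p
    """
    cost = []
    for g, gt_cls in enumerate(gt_classes):
        row = []
        for p in range(len(obj_probs)):
            # Overlap part: -log(IoU) shrinks as the predicted box fits better
            iou_cost = -math.log(ious[g][p] + eps)
            # Class part: BCE between sqrt(cls * obj) and the one-hot class target
            cls_cost = 0.0
            for c, pc in enumerate(cls_probs[p]):
                q = min(max(math.sqrt(pc * obj_probs[p]), eps), 1.0 - eps)
                y = 1.0 if c == gt_cls else 0.0
                cls_cost -= y * math.log(q) + (1.0 - y) * math.log(1.0 - q)
            row.append(cls_cost + iou_weight * iou_cost)
        cost.append(row)
    return cost


# One ground truth of class 0, two candidate predictions
cost = pairwise_cost(ious=[[0.8, 0.2]],
                     cls_probs=[[0.9, 0.1], [0.3, 0.6]],
                     obj_probs=[0.9, 0.5],
                     gt_classes=[0])
print(cost)  # the first candidate (better overlap, right class) is cheaper
```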
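In SimOTA, different objects get different numbers of positive samples (dynamic k). Take the ant-and-watermelon example from Megvii's official answer: traditional assignment schemes usually give a watermelon and an ant in the same scene the same number of positive samples, so either the ant ends up with many low-quality positives or the watermelon gets only one or two. Neither is appropriate.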
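The key to dynamic positive samples is choosing k. SimOTA first takes, for each target, the 10 predicted boxes that overlap it most (highest IoU), then sums their IoUs with the ground-truth box and truncates the sum to an integer (at least 1) to obtain k.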
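The dynamic k rule is short enough to write out directly. This sketch assumes the IoUs between one ground-truth box and its candidate predictions are already available as a list; `q` is the number of top IoUs summed (10 in the description above, twenty in the summary that follows). The name `dynamic_k` and the sample IoU values are made up for illustration.

```python
def dynamic_k(ious_row, q=10):
    """k for one ground truth: the sum of its q largest IoUs, truncated, at least 1."""
    top = sorted(ious_row, reverse=True)[:q]
    return max(1, int(sum(top)))


# Large object: many candidates overlap it well
watermelon = [0.9, 0.85, 0.8, 0.8, 0.75, 0.7, 0.6, 0.5]
# Small object: only a few weak overlaps
ant = [0.3, 0.2, 0.1]
print(dynamic_k(watermelon), dynamic_k(ant))  # 5 1
```

The large object earns several positives while the small one still gets exactly one, which is the behavior the ant-and-watermelon example asks for.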
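The SimOTA procedure can therefore be summarized as follows:
1. Compute the overlap (IoU) between each ground-truth box and the predicted box at each feature point.
2. Sum the IoUs of the (up to) twenty predicted boxes that overlap each ground-truth box most to get that box's k, meaning each ground-truth box is matched to k feature points.
3. Compute how accurately the predicted box at each feature point predicts the class of each ground-truth box.
4. Compute the Cost matrix.
5. Take the k lowest-cost points as the positive samples of that ground-truth box.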
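Step 5 can be sketched as follows, taking a Cost matrix and the per-ground-truth k as given. One detail the list does not mention: when the same prediction is picked by more than one ground-truth box, SimOTA as published for YOLOX keeps only the lowest-cost ground truth. This sketch does the same; whether this repository resolves such conflicts identically is an assumption, and `simota_assign` is a hypothetical name.

```python
def simota_assign(cost, dynamic_ks):
    """Return {prediction index: ground-truth index} from cost[g][p] and per-gt k."""
    num_gt, num_pred = len(cost), len(cost[0])
    claims = {}  # prediction index -> ground truths that picked it
    for g in range(num_gt):
        # Step 5: each ground truth takes its k lowest-cost candidate predictions
        cheapest = sorted(range(num_pred), key=lambda p: cost[g][p])[:dynamic_ks[g]]
        for p in cheapest:
            claims.setdefault(p, []).append(g)
    # A prediction picked by several ground truths keeps only the cheapest one
    return {p: min(gs, key=lambda g: cost[g][p]) for p, gs in sorted(claims.items())}


cost = [[0.2, 0.5, 3.0, 0.9],   # ground truth 0
        [2.0, 0.4, 0.3, 5.0]]   # ground truth 1
print(simota_assign(cost, dynamic_ks=[2, 2]))  # {0: 0, 1: 1, 2: 1}
```

Prediction 1 is among the two cheapest for both ground truths; it costs 0.4 for ground truth 1 versus 0.5 for ground truth 0, so it is assigned to ground truth 1.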
    def build_targets(self, predictions, targets, imgs):
        #-------------------------------------------#
        #   Coarse positive-sample matching
        #-------------------------------------------#
        indices, anch = self.find_3_positive(predictions, targets)

        # Per-layer lists that collect the final matches
        matching_bs = [[] for _ in predictions]
        matching_as = [[] for _ in predictions]
        matching_gjs = [[] for _ in predictions]
        matching_gis = [[] for _ in predictions]
        matching_targets = [[] for _ in predictions]
        matching_anchs = [[] for _ in predictions]

        #-------------------------------------------#
        #   Three layers in total
        #-------------------------------------------#
        num_layer = len(predictions)
        #-------------------------------------------#
        #   Loop over the batch and run OTA matching;
        #   inside the batch loop, loop over the layers
        #-------------------------------------------#
        for batch_idx in range(predictions[0].shape[0]):
            #-------------------------------------------#
            #   Select the matched ground-truth boxes that belong to this image
            #-------------------------------------------#
            b_idx = targets[:, 0] == batch_idx
            this_target = targets[b_idx]
            #-------------------------------------------#
            #   Skip the image if no ground-truth box belongs to it
            #-------------------------------------------#
            if this_target.shape[0] == 0:
                continue

            #-------------------------------------------#
            #   Scale the ground-truth boxes to pixels
            #   (this uses the image height, so it assumes a square input)
            #-------------------------------------------#
            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
            #-------------------------------------------#
            #   Convert center/width/height to top-left/bottom-right corners
            #-------------------------------------------#
            txyxy = self.xywh2xyxy(txywh)

            pxyxys = []
            p_cls = []
            p_obj = []
            from_which_layer = []
            all_b = []
            all_a = []
            all_gj = []
            all_gi = []
            all_anch = []

            #-------------------------------------------#
            #   Loop over the three layers
            #-------------------------------------------#
            for i, prediction in enumerate(predictions):
                #-------------------------------------------#
                #   b: image index, a: anchor index
                #   gj: y (row) index, gi: x (column) index
                #-------------------------------------------#
                b, a, gj, gi = indices[i]
                idx = (b == batch_idx)
                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]

                all_b.append(b)
                all_a.append(a)
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
                from_which_layer.append(torch.ones(size=(len(b),)) * i)
comment">#-------------------------------------------#</span> <span class="token comment"># 取出这个真实框对应的预测结果</span> <span class="token comment">#-------------------------------------------#</span> fg_pred <span class="token operator">=</span> prediction<span class="token punctuation">[</span>b<span class="token punctuation">,</span> a<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> p_obj<span class="token punctuation">.</span>append<span class="token punctuation">(</span>fg_pred<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">:</span><span class="token number">5</span><span class="token punctuation">]</span><span class="token punctuation">)</span> p_cls<span class="token punctuation">.</span>append<span class="token punctuation">(</span>fg_pred<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># 获得网格后,进行解码</span> <span class="token comment">#-------------------------------------------#</span> grid <span class="token operator">=</span> torch<span class="token punctuation">.</span>stack<span class="token punctuation">(</span><span class="token punctuation">[</span>gi<span class="token punctuation">,</span> gj<span class="token punctuation">]</span><span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>type_as<span class="token punctuation">(</span>fg_pred<span class="token 
punctuation">)</span> pxy <span class="token operator">=</span> <span class="token punctuation">(</span>fg_pred<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token number">2.</span> <span class="token operator">-</span> <span class="token number">0.5</span> <span class="token operator">+</span> grid<span class="token punctuation">)</span> <span class="token operator">*</span> self<span class="token punctuation">.</span>stride<span class="token punctuation">[</span>i<span class="token punctuation">]</span> pwh <span class="token operator">=</span> <span class="token punctuation">(</span>fg_pred<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token number">2</span><span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span> <span class="token operator">*</span> anch<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span>idx<span class="token punctuation">]</span> <span class="token operator">*</span> self<span class="token punctuation">.</span>stride<span class="token punctuation">[</span>i<span class="token punctuation">]</span> pxywh <span class="token operator">=</span> torch<span 
class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>pxy<span class="token punctuation">,</span> pwh<span class="token punctuation">]</span><span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span> pxyxy <span class="token operator">=</span> self<span class="token punctuation">.</span>xywh2xyxy<span class="token punctuation">(</span>pxywh<span class="token punctuation">)</span> pxyxys<span class="token punctuation">.</span>append<span class="token punctuation">(</span>pxyxy<span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># 判断是否存在对应的预测框,不存在则跳过</span> <span class="token comment">#-------------------------------------------#</span> pxyxys <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>pxyxys<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> <span class="token keyword">if</span> pxyxys<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span> <span class="token keyword">continue</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># 进行堆叠</span> <span class="token comment">#-------------------------------------------#</span> p_obj <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>p_obj<span class="token punctuation">,</span> dim<span 
class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> p_cls <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>p_cls<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> from_which_layer <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>from_which_layer<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> all_b <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>all_b<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> all_a <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>all_a<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> all_gj <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>all_gj<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> all_gi <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span>all_gi<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> all_anch <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span 
class="token punctuation">(</span>all_anch<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------------------------#</span> <span class="token comment"># 计算当前图片中,真实框与预测框的重合程度</span> <span class="token comment"># iou的范围为0-1,取-log后为0~inf</span> <span class="token comment"># 重合程度越大,取-log后越小</span> <span class="token comment"># 因此,真实框与预测框重合度越大,pair_wise_iou_loss越小</span> <span class="token comment">#-------------------------------------------------------------#</span> pair_wise_iou <span class="token operator">=</span> self<span class="token punctuation">.</span>box_iou<span class="token punctuation">(</span>txyxy<span class="token punctuation">,</span> pxyxys<span class="token punctuation">)</span> pair_wise_iou_loss <span class="token operator">=</span> <span class="token operator">-</span>torch<span class="token punctuation">.</span>log<span class="token punctuation">(</span>pair_wise_iou <span class="token operator">+</span> <span class="token number">1e-8</span><span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># 最多二十个预测框与真实框的重合程度</span> <span class="token comment"># 然后求和,找到每个真实框对应几个预测框</span> <span class="token comment">#-------------------------------------------#</span> top_k<span class="token punctuation">,</span> _ <span class="token operator">=</span> torch<span class="token punctuation">.</span>topk<span class="token punctuation">(</span>pair_wise_iou<span class="token punctuation">,</span> <span class="token builtin">min</span><span class="token punctuation">(</span><span class="token number">20</span><span class="token punctuation">,</span> pair_wise_iou<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">1</span><span 
class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span> dynamic_ks <span class="token operator">=</span> torch<span class="token punctuation">.</span>clamp<span class="token punctuation">(</span>top_k<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">int</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token builtin">min</span><span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># gt_cls_per_image 种类的真实信息</span> <span class="token comment">#-------------------------------------------#</span> gt_cls_per_image <span class="token operator">=</span> F<span class="token punctuation">.</span>one_hot<span class="token punctuation">(</span>this_target<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span>to<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>int64<span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>num_classes<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token 
punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> pxyxys<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> <span class="token comment">#-------------------------------------------#</span> <span class="token comment"># cls_preds_ 种类置信度的预测信息</span> <span class="token comment"># cls_preds_越接近于1,y越接近于1</span> <span class="token comment"># y / (1 - y)越接近于无穷大</span> <span class="token comment"># 也就是种类置信度预测的越准</span> <span class="token comment"># pair_wise_cls_loss越小</span> <span class="token comment">#-------------------------------------------#</span> num_gt <span class="token operator">=</span> this_target<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> cls_preds_ <span class="token operator">=</span> p_cls<span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</