Building a YOLOv3 Object Detection Platform with PyTorch

Notes

In the YOLOv3 network structure diagram, the channel count of the feature layer with the smallest width and height is wrong: the correct shape of that output feature layer is [batch_size, 13, 13, 512]. The code itself is correct.

Preface

Let's take a look at a PyTorch implementation of YOLOv3 and, along the way, train it on our own data.

Source Code Download

https://github.com/bubbliiiing/yolo3-pytorch
If you like it, please give it a star.

YOLOv3 Implementation Approach

I. Prediction

1. The backbone network Darknet53

(figure: Darknet53 network structure)
PS: the figure has a small issue. After the feature layer with the smallest width and height passes through Conv2D Block 5L, its shape according to the code should be (batch_size,13,13,512), not the (batch_size,13,13,1024) shown in the figure.

YOLOv3 uses Darknet53 as its backbone feature extraction network, which has two important characteristics:
1. Darknet53 makes use of residual blocks (Residual). A residual stage in Darknet53 first applies a 3x3 convolution with stride 2, which compresses the width and height of the incoming feature layer; we call the resulting feature layer layer. We then apply a 1x1 convolution and a 3x3 convolution to it and add the result back onto layer, which forms the residual structure. By repeatedly stacking 1x1 convolutions, 3x3 convolutions, and residual connections, the network is made much deeper. Residual networks are easy to optimize and can gain accuracy from considerably increased depth; their internal residual blocks use skip connections, which alleviate the vanishing-gradient problem caused by adding depth to a deep neural network.

2. Every convolution in Darknet53 uses the dedicated DarknetConv2D structure: each convolution uses L2 regularization and is followed by BatchNormalization and LeakyReLU. An ordinary ReLU sets all negative values to zero, whereas LeakyReLU gives negative values a small non-zero slope. Mathematically:

LeakyReLU(x) = x for x >= 0, and a*x for x < 0 (here a = 0.1)

The implementation is:

import math
from collections import OrderedDict

import torch
import torch.nn as nn

#---------------------------------------------------------------------#
#   Residual block
#   A 1x1 convolution first reduces the number of channels, then a 3x3
#   convolution extracts features and restores the channel count.
#   Finally the input is added back through the residual connection.
#---------------------------------------------------------------------#
class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes[0])
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes[1])
        self.relu2 = nn.LeakyReLU(0.1)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out += residual
        return out

class DarkNet(nn.Module):
    def __init__(self, layers):
        super(DarkNet, self).__init__()
        self.inplanes = 32
        # 416,416,3 -> 416,416,32
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu1 = nn.LeakyReLU(0.1)

        # 416,416,32 -> 208,208,64
        self.layer1 = self._make_layer([32, 64], layers[0])
        # 208,208,64 -> 104,104,128
        self.layer2 = self._make_layer([64, 128], layers[1])
        # 104,104,128 -> 52,52,256
        self.layer3 = self._make_layer([128, 256], layers[2])
        # 52,52,256 -> 26,26,512
        self.layer4 = self._make_layer([256, 512], layers[3])
        # 26,26,512 -> 13,13,1024
        self.layer5 = self._make_layer([512, 1024], layers[4])

        self.layers_out_filters = [64, 128, 256, 512, 1024]

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    #---------------------------------------------------------------------#
    #   Each layer first downsamples with a 3x3 convolution of stride 2,
    #   then stacks the residual blocks.
    #---------------------------------------------------------------------#
    def _make_layer(self, planes, blocks):
        layers = []
        # downsampling: stride 2, kernel size 3
        layers.append(("ds_conv", nn.Conv2d(self.inplanes, planes[1], kernel_size=3,
                                            stride=2, padding=1, bias=False)))
        layers.append(("ds_bn", nn.BatchNorm2d(planes[1])))
        layers.append(("ds_relu", nn.LeakyReLU(0.1)))
        # stack the residual blocks
        self.inplanes = planes[1]
        for i in range(0, blocks):
            layers.append(("residual_{}".format(i), BasicBlock(self.inplanes, planes)))
        return nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        out3 = self.layer3(x)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)

        return out3, out4, out5

def darknet53(pretrained, **kwargs):
    model = DarkNet([1, 2, 8, 8, 4])
    if pretrained:
        if isinstance(pretrained, str):
            model.load_state_dict(torch.load(pretrained))
        else:
            raise Exception("darknet requires a pretrained path, got [{}]".format(pretrained))
    return model
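As a quick sanity check of the backbone's output shapes, the snippet below (a minimal sketch; it simply feeds a random 416x416 tensor through the module above) prints the three effective feature layers:

import torch

model = darknet53(None)                  # build Darknet53 without pretrained weights
out3, out4, out5 = model(torch.randn(1, 3, 416, 416))
print(out3.shape)                        # torch.Size([1, 256, 52, 52])
print(out4.shape)                        # torch.Size([1, 512, 26, 26])
print(out5.shape)                        # torch.Size([1, 1024, 13, 13])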


2. Obtaining Predictions from the Features

(figure: FPN feature pyramid and Yolo Head structure)
Obtaining predictions from the features can be divided into two parts:

  • building the FPN feature pyramid for enhanced feature extraction
  • using the Yolo Head to produce predictions from the three effective feature layers

a. Building the FPN feature pyramid for enhanced feature extraction

For feature utilization, YOLOv3 extracts multiple feature layers for object detection; three feature layers are extracted in total.
The three feature layers sit at different depths of the Darknet53 backbone (middle, lower-middle, and bottom), and their shapes are (52,52,256), (26,26,512), and (13,13,1024).

After obtaining the three effective feature layers, we use them to build the FPN as follows:

  1. The 13x13x1024 feature layer goes through 5 convolutions. The result is passed to the Yolo Head to obtain one set of predictions, and is also upsampled (nearest-neighbor upsampling) and combined with the 26x26x512 feature layer; the combined feature layer has shape (26,26,768).
  2. The combined feature layer again goes through 5 convolutions. The result is passed to the Yolo Head to obtain another set of predictions, and is also upsampled and combined with the 52x52x256 feature layer; the combined feature layer has shape (52,52,384).
  3. The combined feature layer goes through 5 convolutions once more, and the result is passed to the Yolo Head to obtain the last set of predictions.

The feature pyramid fuses feature layers of different shapes, which helps extract better features (see the sketch below).
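To make the first merge step concrete, here is a minimal sketch of the reduce-upsample-concatenate operation with dummy tensors (the variable names are illustrative, not the repository's):

import torch
import torch.nn as nn

p5_branch = torch.randn(1, 512, 13, 13)    # 13x13 branch after its 5 convolutions
feat_26   = torch.randn(1, 512, 26, 26)    # the 26,26,512 backbone feature

reduce_and_up = nn.Sequential(
    nn.Conv2d(512, 256, kernel_size=1, bias=False),   # 13,13,512 -> 13,13,256
    nn.Upsample(scale_factor=2, mode='nearest'),      # 13,13,256 -> 26,26,256
)
merged = torch.cat([reduce_and_up(p5_branch), feat_26], dim=1)
print(merged.shape)   # torch.Size([1, 768, 26, 26]), the (26,26,768) layer of step 1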

b. Using the Yolo Head to obtain predictions

Using the FPN feature pyramid we obtain three enhanced features with shapes (13,13,512), (26,26,256), and (52,52,128); these three feature layers are then passed into the Yolo Head to obtain the predictions.

The Yolo Head is essentially a 3x3 convolution followed by a 1x1 convolution: the 3x3 convolution integrates features and the 1x1 convolution adjusts the number of channels.

The three feature layers are processed separately. Assuming we are predicting on the VOC dataset, the output shapes are (13,13,75), (26,26,75), and (52,52,75). The last dimension is 75 because VOC has 20 classes and YOLOv3 places 3 prior boxes at every grid point of every feature layer, so the number of prediction channels is 3x25.
If the COCO training set is used instead, there are 80 classes, so the last dimension is 255 = 3x85 and the three feature layers' shapes are (13,13,255), (26,26,255), and (52,52,255).

In practice, feeding in N images of size 416x416 produces, after many layers of computation, three outputs with shapes (N,13,13,255), (N,26,26,255), and (N,52,52,255), which correspond to the 3 prior boxes at every cell of the 13x13, 26x26, and 52x52 grids into which each image is divided.
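As a quick numeric check of these channel counts (a minimal sketch; the variable names are illustrative only):

num_anchors_per_level = 3
voc_channels = num_anchors_per_level * (5 + 20)    # 4 box params + 1 objectness + 20 classes -> 75
coco_channels = num_anchors_per_level * (5 + 80)   # 4 box params + 1 objectness + 80 classes -> 255
print(voc_channels, coco_channels)                 # 75 255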

The implementation is as follows:

from collections import OrderedDict

import torch
import torch.nn as nn

from nets.darknet import darknet53

def conv2d(filter_in, filter_out, kernel_size):
    pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=1, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.LeakyReLU(0.1)),
    ]))

#------------------------------------------------------------------------#
#   make_last_layers contains seven convolutions in total: the first five
#   extract features, the last two produce the yolo prediction.
#------------------------------------------------------------------------#
def make_last_layers(filters_list, in_filters, out_filter):
    m = nn.ModuleList([
        conv2d(in_filters, filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        conv2d(filters_list[1], filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        conv2d(filters_list[1], filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        nn.Conv2d(filters_list[1], out_filter, kernel_size=1,
                  stride=1, padding=0, bias=True)
    ])
    return m

class YoloBody(nn.Module):
    def __init__(self, anchor, num_classes):
        super(YoloBody, self).__init__()
        #---------------------------------------------------#
        #   Build the darknet53 backbone.
        #   It returns three effective feature layers with shapes:
        #   52,52,256
        #   26,26,512
        #   13,13,1024
        #---------------------------------------------------#
        self.backbone = darknet53(None)

        # out_filters : [64, 128, 256, 512, 1024]
        out_filters = self.backbone.layers_out_filters

        #------------------------------------------------------------------------#
        #   Output channels of the yolo heads; for the VOC dataset
        #   final_out_filter0 = final_out_filter1 = final_out_filter2 = 75
        #------------------------------------------------------------------------#
        final_out_filter0 = len(anchor[0]) * (5 + num_classes)
        self.last_layer0 = make_last_layers([512, 1024], out_filters[-1], final_out_filter0)

        final_out_filter1 = len(anchor[1]) * (5 + num_classes)
        self.last_layer1_conv = conv2d(512, 256, 1)
        self.last_layer1_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.last_layer1 = make_last_layers([256, 512], out_filters[-2] + 256, final_out_filter1)

        final_out_filter2 = len(anchor[2]) * (5 + num_classes)
        self.last_layer2_conv = conv2d(256, 128, 1)
        self.last_layer2_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.last_layer2 = make_last_layers([128, 256], out_filters[-3] + 128, final_out_filter2)

    def forward(self, x):
        def _branch(last_layer, layer_in):
            for i, e in enumerate(last_layer):
                layer_in = e(layer_in)
                if i == 4:
                    out_branch = layer_in
            return layer_in, out_branch
        #---------------------------------------------------#
        #   Obtain the three effective feature layers; their shapes are:
        #   52,52,256; 26,26,512; 13,13,1024
        #---------------------------------------------------#
        x2, x1, x0 = self.backbone(x)

        #---------------------------------------------------#
        #   First feature layer
        #   out0 = (batch_size,255,13,13)
        #---------------------------------------------------#
        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512
        out0, out0_branch = _branch(self.last_layer0, x0)

        # 13,13,512 -> 13,13,256 -> 26,26,256
        x1_in = self.last_layer1_conv(out0_branch)
        x1_in = self.last_layer1_upsample(x1_in)

        # 26,26,256 + 26,26,512 -> 26,26,768
        x1_in = torch.cat([x1_in, x1], 1)
        #---------------------------------------------------#
        #   Second feature layer
        #   out1 = (batch_size,255,26,26)
        #---------------------------------------------------#
        # 26,26,768 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
        out1, out1_branch = _branch(self.last_layer1, x1_in)

        # 26,26,256 -> 26,26,128 -> 52,52,128
        x2_in = self.last_layer2_conv(out1_branch)
        x2_in = self.last_layer2_upsample(x2_in)

        # 52,52,128 + 52,52,256 -> 52,52,384
        x2_in = torch.cat([x2_in, x2], 1)
        #---------------------------------------------------#
        #   Third feature layer
        #   out2 = (batch_size,255,52,52)
        #---------------------------------------------------#
        # 52,52,384 -> 52,52,128 -> 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128
        out2, _ = _branch(self.last_layer2, x2_in)
        return out0, out1, out2

3. Decoding the Predictions

From step 2 we obtain the predictions of the three feature layers, with shapes:

  • (N,13,13,255)
  • (N,26,26,255)
  • (N,52,52,255)

Let's briefly look at what each effective feature layer actually does:
each effective feature layer divides the whole image into a grid matching its width and height; for example, the (N,13,13,255) feature layer divides the image into a 13x13 grid. Multiple prior boxes are then placed at each grid cell. These boxes are preset by the network, and the network's predictions determine whether each box contains an object and what class that object is.

Since every grid point has three prior boxes, the predictions above can be reshaped to:

  • (N,13,13,3,85)
  • (N,26,26,3,85)
  • (N,52,52,3,85)

The 85 splits into 4+1+80: 4 are the adjustment parameters of the prior box, 1 indicates whether the prior box contains an object, and 80 are the class scores (COCO has 80 classes). If YOLOv3 only had to detect two classes, the 85 would become 4+1+2 = 7.

That is, the 85 values consist of 4+1+80, representing x_offset, y_offset, h and w, the confidence, and the classification result.
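A minimal shape check of this reshape, following the axis order used by the DecodeBox code further below (the anchor axis comes before the grid axes):

import torch

raw = torch.randn(4, 255, 13, 13)   # N = 4 images, one 13x13 head output (COCO, 80 classes)
pred = raw.view(4, 3, 85, 13, 13).permute(0, 1, 3, 4, 2).contiguous()
print(pred.shape)                   # torch.Size([4, 3, 13, 13, 85])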

These predictions, however, do not yet correspond to the final boxes' positions on the image; decoding is still required.

Decoding in YOLOv3 takes two steps:

  • add each grid point's coordinates to its corresponding x_offset and y_offset; the result is the center of the prediction box
  • then combine the prior box with h and w to compute the prediction box's width and height. This gives the position of the whole prediction box.
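In formula form, with (c_x, c_y) the top-left corner of the grid cell, (p_w, p_h) the prior box's width and height on the feature layer, and (t_x, t_y, t_w, t_h) the raw network outputs, the decoded box is

b_x = sigmoid(t_x) + c_x
b_y = sigmoid(t_y) + c_y
b_w = p_w * exp(t_w)
b_h = p_h * exp(t_h)

which is exactly what the DecodeBox code below computes before scaling everything back to the input image size.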

(figure: adjusting the prior boxes into prediction boxes)
After obtaining the final predictions, score sorting and non-maximum suppression are still required.

This part is essentially common to all object detectors. For each class:
1. take the boxes and scores of that class whose score is greater than self.obj_threshold;
2. apply non-maximum suppression using the box positions and scores.

The implementation is as follows:

class DecodeBox(nn.Module):
    def __init__(self, anchors, num_classes, img_size):
        super(DecodeBox, self).__init__()
        #-----------------------------------------------------------#
        #   The anchors for the 13x13 feature layer are [116,90],[156,198],[373,326]
        #   The anchors for the 26x26 feature layer are [30,61],[62,45],[59,119]
        #   The anchors for the 52x52 feature layer are [10,13],[16,30],[33,23]
        #-----------------------------------------------------------#
        self.anchors = anchors
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.img_size = img_size

    def forward(self, input):
        #-----------------------------------------------#
        #   There are three inputs in total, with shapes
        #   batch_size, 255, 13, 13
        #   batch_size, 255, 26, 26
        #   batch_size, 255, 52, 52
        #-----------------------------------------------#
        batch_size = input.size(0)
        input_height = input.size(2)
        input_width = input.size(3)

        #-----------------------------------------------#
        #   With a 416x416 input
        #   stride_h = stride_w = 32, 16, 8
        #-----------------------------------------------#
        stride_h = self.img_size[1] / input_height
        stride_w = self.img_size[0] / input_width
        #-------------------------------------------------#
        #   The scaled_anchors obtained here are relative to the feature layer
        #-------------------------------------------------#
        scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors]

        #-----------------------------------------------#
        #   Reshape the inputs to
        #   batch_size, 3, 13, 13, 85
        #   batch_size, 3, 26, 26, 85
        #   batch_size, 3, 52, 52, 85
        #-----------------------------------------------#
        prediction = input.view(batch_size, self.num_anchors,
                                self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()

        # adjustment parameters for the prior box centers
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        # adjustment parameters for the prior box width and height
        w = prediction[..., 2]
        h = prediction[..., 3]
        # confidence: whether an object is present
        conf = torch.sigmoid(prediction[..., 4])
        # class confidences
        pred_cls = torch.sigmoid(prediction[..., 5:])

        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

        #----------------------------------------------------------#
        #   Generate the grid; the prior box centers start at the
        #   grid cells' top-left corners
        #   batch_size,3,13,13
        #----------------------------------------------------------#
        grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
            batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
            batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor)

        #----------------------------------------------------------#
        #   Generate the prior box widths and heights in grid form
        #   batch_size,3,13,13
        #----------------------------------------------------------#
        anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
        anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

        #----------------------------------------------------------#
        #   Adjust the prior boxes using the predictions.
        #   First shift the prior box center from the grid cell's
        #   top-left corner towards the bottom-right,
        #   then adjust the prior box's width and height.
        #----------------------------------------------------------#
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = x.data + grid_x
        pred_boxes[..., 1] = y.data + grid_y
        pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
        pred_boxes[..., 3] = torch.exp(h.data) * anchor_h

        #----------------------------------------------------------#
        #   Rescale the outputs to the input image size
        #----------------------------------------------------------#
        _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
        output = torch.cat((pred_boxes.view(batch_size, -1, 4) * _scale,
                            conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
        return output.data

4. Drawing on the Original Image

From step 3 we obtain the positions of the prediction boxes on the original image, and these boxes have already been filtered. The filtered boxes can be drawn directly on the image to get the final result.

II. Training

1. Parameters needed to compute the loss

Computing the loss is really a comparison between pred and target:
pred is the network's prediction.
target is the ground-truth boxes of the image.

2. What is pred?

For the YOLOv3 model, the network's final output is, for every grid point of each of the three feature layers, the prediction boxes and their classes; that is, the three feature layers correspond to dividing the image into grids of different sizes, and at each grid point there are three prior boxes with their positions, confidences, and classes.

The output shapes are (13,13,75), (26,26,75), and (52,52,75). The last dimension is 75 because this is based on the VOC dataset, which has 20 classes, and YOLOv3 places 3 prior boxes at every grid point of each feature layer, so the last dimension is 3x25.
If the COCO training set is used, there are 80 classes and the last dimension is 255 = 3x85, so the three feature layers' shapes are (13,13,255), (26,26,255), and (52,52,255).

At this point y_pre is still undecoded; only after decoding does it describe the actual boxes on the image.

3. What is target?

target describes the ground-truth boxes of a real image.
Its first dimension is batch_size, its second dimension is the number of ground-truth boxes in each image, and its third dimension holds the information of each ground-truth box, including its position and class.
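For illustration, a batch of two images might carry targets like the following (a hypothetical example, assuming each box is stored as normalized center x, center y, width, height plus a class index):

import numpy as np

targets = [
    np.array([[0.50, 0.40, 0.20, 0.30, 14],     # image 1: a box of class 14
              [0.25, 0.70, 0.10, 0.15,  6]]),   # image 1: a box of class 6
    np.array([[0.60, 0.55, 0.35, 0.40,  0]]),   # image 2: a single box of class 0
]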

4. The loss computation process

Once we have pred and target, we cannot simply subtract one from the other; the following steps are needed.

  1. Determine where each ground-truth box lies in the image, i.e. which grid point is responsible for detecting it.
  2. Determine which prior box overlaps the ground-truth box the most.
  3. Compute the prediction that grid point would have to produce in order to obtain the ground-truth box.
  4. Do the above for every ground-truth box.
  5. Take the predictions the network should have produced and compare them with its actual predictions.
import os

import math
import numpy as np
import scipy.signal
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

def jaccard(_box_a, _box_b):
    # Convert the ground-truth boxes from (cx, cy, w, h) to top-left / bottom-right corners
    b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
    b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
    # Convert the prior boxes (anchors) in the same way
    b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
    b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
    box_a = torch.zeros_like(_box_a)
    box_b = torch.zeros_like(_box_b)
    box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
    box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
    A = box_a.size(0)
    B = box_b.size(0)
    # Intersection rectangle for every (ground-truth box, anchor) pair
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)

    inter = inter[:, :, 0] * inter[:, :, 1]
    # Areas of the ground-truth boxes and of the anchors
    area_a = ((box_a[:, 2] - box_a[:, 0]) *
              (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2] - box_b[:, 0]) *
              (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    # IoU
    union = area_a + area_b - inter
    return inter / union  # [A,B]

def clip_by_tensor(t, t_min, t_max):
    # Element-wise clamp of tensor t into [t_min, t_max]
    t = t.float()
    result = (t >= t_min).float() * t + (t < t_min).float() * t_min
    result = (result <= t_max).float() * result + (result > t_max).float() * t_max
    return result

def MSELoss(pred, target):
    return (pred - target) ** 2

def BCELoss(pred, target):
    epsilon = 1e-7
    pred = clip_by_tensor(pred, epsilon, 1.0 - epsilon)
    output = -target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
    return output

class YOLOLoss(nn.Module):
    def __init__(self, anchors, num_classes, img_size, cuda, normalize):
        super(YOLOLoss, self).__init__()
        #-----------------------------------------------------------#
        #   The anchors of the 13x13 feature layer are [116,90],[156,198],[373,326]
        #   The anchors of the 26x26 feature layer are [30,61],[62,45],[59,119]
        #   The anchors of the 52x52 feature layer are [10,13],[16,30],[33,23]
        #-----------------------------------------------------------#
        self.anchors = anchors
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        #-------------------------------------#
        #   Widths/heights of the feature layers
        #   13, 26, 52
        #-------------------------------------#
        self.feature_length = [img_size[0] // 32, img_size[0] // 16, img_size[0] // 8]
        self.img_size = img_size

        self.ignore_threshold = 0.5
        self.lambda_xy = 1.0
        self.lambda_wh = 1.0
        self.lambda_conf = 1.0
        self.lambda_cls = 1.0
        self.cuda = cuda
        self.normalize = normalize

<span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> <span class="token builtin">input</span><span class="token punctuation">,</span> targets<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token comment">#----------------------------------------------------#</span>
    <span class="token comment">#   input的shape为  bs, 3*(5+num_classes), 13, 13</span>
    <span class="token comment">#                   bs, 3*(5+num_classes), 26, 26</span>
    <span class="token comment">#                   bs, 3*(5+num_classes), 52, 52</span>
    <span class="token comment">#----------------------------------------------------#</span>
    
    <span class="token comment">#-----------------------#</span>
    <span class="token comment">#   一共多少张图片</span>
    <span class="token comment">#-----------------------#</span>
    bs <span class="token operator">=</span> <span class="token builtin">input</span><span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span>
    <span class="token comment">#-----------------------#</span>
    <span class="token comment">#   特征层的高</span>
    <span class="token comment">#-----------------------#</span>
    in_h <span class="token operator">=</span> <span class="token builtin">input</span><span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">)</span>
    <span class="token comment">#-----------------------#</span>
    <span class="token comment">#   特征层的宽</span>
    <span class="token comment">#-----------------------#</span>
    in_w <span class="token operator">=</span> <span class="token builtin">input</span><span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token number">3</span><span class="token punctuation">)</span>

    <span class="token comment">#-----------------------------------------------------------------------#</span>
    <span class="token comment">#   计算步长</span>
    <span class="token comment">#   每一个特征点对应原来的图片上多少个像素点</span>
    <span class="token comment">#   如果特征层为13x13的话,一个特征点就对应原来的图片上的32个像素点</span>
    <span class="token comment">#   如果特征层为26x26的话,一个特征点就对应原来的图片上的16个像素点</span>
    <span class="token comment">#   如果特征层为52x52的话,一个特征点就对应原来的图片上的8个像素点</span>
    <span class="token comment">#   stride_h = stride_w = 32、16、8</span>
    <span class="token comment">#-----------------------------------------------------------------------#</span>
    stride_h <span class="token operator">=</span> self<span class="token punctuation">.</span>img_size<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">/</span> in_h
    stride_w <span class="token operator">=</span> self<span class="token punctuation">.</span>img_size<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> <span class="token operator">/</span> in_w

    <span class="token comment">#-------------------------------------------------#</span>
    <span class="token comment">#   此时获得的scaled_anchors大小是相对于特征层的</span>
    <span class="token comment">#-------------------------------------------------#</span>
    scaled_anchors <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">(</span>a_w <span class="token operator">/</span> stride_w<span class="token punctuation">,</span> a_h <span class="token operator">/</span> stride_h<span class="token punctuation">)</span> <span class="token keyword">for</span> a_w<span class="token punctuation">,</span> a_h <span class="token keyword">in</span> self<span class="token punctuation">.</span>anchors<span class="token punctuation">]</span>
    
    <span class="token comment">#-----------------------------------------------#</span>
    <span class="token comment">#   输入的input一共有三个,他们的shape分别是</span>
    <span class="token comment">#   batch_size, 3, 13, 13, 5 + num_classes</span>
    <span class="token comment">#   batch_size, 3, 26, 26, 5 + num_classes</span>
    <span class="token comment">#   batch_size, 3, 52, 52, 5 + num_classes</span>
    <span class="token comment">#-----------------------------------------------#</span>
    prediction <span class="token operator">=</span> <span class="token builtin">input</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                            self<span class="token punctuation">.</span>bbox_attrs<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">)</span><span class="token punctuation">.</span>permute<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>contiguous<span class="token punctuation">(</span><span class="token punctuation">)</span>
    
    <span class="token comment"># 先验框的中心位置的调整参数</span>
    x <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    y <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token comment"># 先验框的宽高调整参数</span>
    w <span class="token operator">=</span> prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span>
    h <span class="token operator">=</span> prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span>
    <span class="token comment"># 获得置信度,是否有物体</span>
    conf <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token comment"># 种类置信度</span>
    pred_cls <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">)</span>

    <span class="token comment">#---------------------------------------------------------------#</span>
    <span class="token comment">#   找到哪些先验框内部包含物体</span>
    <span class="token comment">#   利用真实框和先验框计算交并比</span>
    <span class="token comment">#   mask        batch_size, 3, in_h, in_w   无目标的特征点</span>
    <span class="token comment">#   noobj_mask  batch_size, 3, in_h, in_w   有目标的特征点</span>
    <span class="token comment">#   tx          batch_size, 3, in_h, in_w   中心x偏移情况</span>
    <span class="token comment">#   ty          batch_size, 3, in_h, in_w   中心y偏移情况</span>
    <span class="token comment">#   tw          batch_size, 3, in_h, in_w   宽高调整参数的真实值</span>
    <span class="token comment">#   th          batch_size, 3, in_h, in_w   宽高调整参数的真实值</span>
    <span class="token comment">#   tconf       batch_size, 3, in_h, in_w   置信度真实值</span>
    <span class="token comment">#   tcls        batch_size, 3, in_h, in_w, num_classes  种类真实值</span>
    <span class="token comment">#----------------------------------------------------------------#</span>
    mask<span class="token punctuation">,</span> noobj_mask<span class="token punctuation">,</span> tx<span class="token punctuation">,</span> ty<span class="token punctuation">,</span> tw<span class="token punctuation">,</span> th<span class="token punctuation">,</span> tconf<span class="token punctuation">,</span> tcls<span class="token punctuation">,</span> box_loss_scale_x<span class="token punctuation">,</span> box_loss_scale_y <span class="token operator">=</span>\
                                                                        self<span class="token punctuation">.</span>get_target<span class="token punctuation">(</span>targets<span class="token punctuation">,</span> scaled_anchors<span class="token punctuation">,</span>
                                                                                        in_w<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span>
                                                                                        self<span class="token punctuation">.</span>ignore_threshold<span class="token punctuation">)</span>

    <span class="token comment">#---------------------------------------------------------------#</span>
    <span class="token comment">#   将预测结果进行解码,判断预测结果和真实值的重合程度</span>
    <span class="token comment">#   如果重合程度过大则忽略,因为这些特征点属于预测比较准确的特征点</span>
    <span class="token comment">#   作为负样本不合适</span>
    <span class="token comment">#----------------------------------------------------------------#</span>
    noobj_mask <span class="token operator">=</span> self<span class="token punctuation">.</span>get_ignore<span class="token punctuation">(</span>prediction<span class="token punctuation">,</span> targets<span class="token punctuation">,</span> scaled_anchors<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> noobj_mask<span class="token punctuation">)</span>

    <span class="token keyword">if</span> self<span class="token punctuation">.</span>cuda<span class="token punctuation">:</span>
        box_loss_scale_x <span class="token operator">=</span> <span class="token punctuation">(</span>box_loss_scale_x<span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
        box_loss_scale_y <span class="token operator">=</span> <span class="token punctuation">(</span>box_loss_scale_y<span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
        mask<span class="token punctuation">,</span> noobj_mask <span class="token operator">=</span> mask<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> noobj_mask<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
        tx<span class="token punctuation">,</span> ty<span class="token punctuation">,</span> tw<span class="token punctuation">,</span> th <span class="token operator">=</span> tx<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> ty<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> tw<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> th<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
        tconf<span class="token punctuation">,</span> tcls <span class="token operator">=</span> tconf<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> tcls<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
    box_loss_scale <span class="token operator">=</span> <span class="token number">2</span> <span class="token operator">-</span> box_loss_scale_x <span class="token operator">*</span> box_loss_scale_y
    
    <span class="token comment"># 计算中心偏移情况的loss,使用BCELoss效果好一些</span>
    loss_x <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>BCELoss<span class="token punctuation">(</span>x<span class="token punctuation">,</span> tx<span class="token punctuation">)</span> <span class="token operator">*</span> box_loss_scale <span class="token operator">*</span> mask<span class="token punctuation">)</span>
    loss_y <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>BCELoss<span class="token punctuation">(</span>y<span class="token punctuation">,</span> ty<span class="token punctuation">)</span> <span class="token operator">*</span> box_loss_scale <span class="token operator">*</span> mask<span class="token punctuation">)</span>
    <span class="token comment"># 计算宽高调整值的loss</span>
    loss_w <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>MSELoss<span class="token punctuation">(</span>w<span class="token punctuation">,</span> tw<span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token number">0.5</span> <span class="token operator">*</span> box_loss_scale <span class="token operator">*</span> mask<span class="token punctuation">)</span>
    loss_h <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>MSELoss<span class="token punctuation">(</span>h<span class="token punctuation">,</span> th<span class="token punctuation">)</span> <span class="token operator">*</span> <span class="token number">0.5</span> <span class="token operator">*</span> box_loss_scale <span class="token operator">*</span> mask<span class="token punctuation">)</span>
    <span class="token comment"># 计算置信度的loss</span>
    loss_conf <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>BCELoss<span class="token punctuation">(</span>conf<span class="token punctuation">,</span> mask<span class="token punctuation">)</span> <span class="token operator">*</span> mask<span class="token punctuation">)</span> <span class="token operator">+</span> \
                torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>BCELoss<span class="token punctuation">(</span>conf<span class="token punctuation">,</span> mask<span class="token punctuation">)</span> <span class="token operator">*</span> noobj_mask<span class="token punctuation">)</span>
                
    loss_cls <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>BCELoss<span class="token punctuation">(</span>pred_cls<span class="token punctuation">[</span>mask <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> tcls<span class="token punctuation">[</span>mask <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

    loss <span class="token operator">=</span> loss_x <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_xy <span class="token operator">+</span> loss_y <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_xy <span class="token operator">+</span> \
            loss_w <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_wh <span class="token operator">+</span> loss_h <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_wh <span class="token operator">+</span> \
            loss_conf <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_conf <span class="token operator">+</span> loss_cls <span class="token operator">*</span> self<span class="token punctuation">.</span>lambda_cls

    <span class="token comment"># print(loss, loss_x.item() + loss_y.item(), loss_w.item() + loss_h.item(), </span>
    <span class="token comment">#         loss_conf.item(), loss_cls.item(), \</span>
    <span class="token comment">#         torch.sum(mask),torch.sum(noobj_mask))</span>
    <span class="token keyword">if</span> self<span class="token punctuation">.</span>normalize<span class="token punctuation">:</span>
        num_pos <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>mask<span class="token punctuation">)</span>
        num_pos <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">max</span><span class="token punctuation">(</span>num_pos<span class="token punctuation">,</span> torch<span class="token punctuation">.</span>ones_like<span class="token punctuation">(</span>num_pos<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">else</span><span class="token punctuation">:</span>
        num_pos <span class="token operator">=</span> bs<span class="token operator">/</span><span class="token number">3</span>
    <span class="token keyword">return</span> loss<span class="token punctuation">,</span> num_pos

<span class="token keyword">def</span> <span class="token function">get_target</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> target<span class="token punctuation">,</span> anchors<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> ignore_threshold<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token comment">#-----------------------------------------------------#</span>
    <span class="token comment">#   计算一共有多少张图片</span>
    <span class="token comment">#-----------------------------------------------------#</span>
    bs <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>target<span class="token punctuation">)</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    <span class="token comment">#   获得当前特征层先验框所属的编号,方便后面对先验框筛选</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    anchor_index <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">,</span><span class="token punctuation">[</span><span class="token number">3</span><span class="token punctuation">,</span><span class="token number">4</span><span class="token punctuation">,</span><span class="token number">5</span><span class="token punctuation">]</span><span class="token punctuation">,</span><span class="token punctuation">[</span><span class="token number">6</span><span class="token punctuation">,</span><span class="token number">7</span><span class="token punctuation">,</span><span class="token number">8</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>feature_length<span class="token punctuation">.</span>index<span class="token punctuation">(</span>in_w<span class="token punctuation">)</span><span class="token punctuation">]</span>
    subtract_index <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span><span class="token number">3</span><span class="token punctuation">,</span><span class="token number">6</span><span class="token punctuation">]</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>feature_length<span class="token punctuation">.</span>index<span class="token punctuation">(</span>in_w<span class="token punctuation">)</span><span class="token punctuation">]</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    <span class="token comment">#   创建全是0或者全是1的阵列</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    mask <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    noobj_mask <span class="token operator">=</span> torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>

    tx <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    ty <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    tw <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    th <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    tconf <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    tcls <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> self<span class="token punctuation">.</span>num_classes<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>

    box_loss_scale_x <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    box_loss_scale_y <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> in_w<span class="token punctuation">,</span> requires_grad<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    <span class="token keyword">for</span> b <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>bs<span class="token punctuation">)</span><span class="token punctuation">:</span>            
        <span class="token keyword">if</span> <span class="token builtin">len</span><span class="token punctuation">(</span>target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token operator">==</span><span class="token number">0</span><span class="token punctuation">:</span>
            <span class="token keyword">continue</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算出正样本在特征层上的中心点</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        gxs <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">:</span><span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_w
        gys <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_h
        
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算出正样本相对于特征层的宽高</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        gws <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">:</span><span class="token number">3</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_w
        ghs <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_h

        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算出正样本属于特征层的哪个特征点</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        gis <span class="token operator">=</span> torch<span class="token punctuation">.</span>floor<span class="token punctuation">(</span>gxs<span class="token punctuation">)</span>
        gjs <span class="token operator">=</span> torch<span class="token punctuation">.</span>floor<span class="token punctuation">(</span>gys<span class="token punctuation">)</span>
        
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   将真实框转换一个形式</span>
        <span class="token comment">#   num_true_box, 4</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        gt_box <span class="token operator">=</span> torch<span class="token punctuation">.</span>FloatTensor<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>torch<span class="token punctuation">.</span>zeros_like<span class="token punctuation">(</span>gws<span class="token punctuation">)</span><span class="token punctuation">,</span> torch<span class="token punctuation">.</span>zeros_like<span class="token punctuation">(</span>ghs<span class="token punctuation">)</span><span class="token punctuation">,</span> gws<span class="token punctuation">,</span> ghs<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   将先验框转换一个形式</span>
        <span class="token comment">#   9, 4</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        anchor_shapes <span class="token operator">=</span> torch<span class="token punctuation">.</span>FloatTensor<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>num_anchors<span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">,</span> torch<span class="token punctuation">.</span>FloatTensor<span class="token punctuation">(</span>anchors<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算交并比</span>
        <span class="token comment">#   num_true_box, 9</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        anch_ious <span class="token operator">=</span> jaccard<span class="token punctuation">(</span>gt_box<span class="token punctuation">,</span> anchor_shapes<span class="token punctuation">)</span>

        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算重合度最大的先验框是哪个</span>
        <span class="token comment">#   num_true_box, </span>
        <span class="token comment">#-------------------------------------------------------#</span>
        best_ns <span class="token operator">=</span> torch<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>anch_ious<span class="token punctuation">,</span>dim<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>
        <span class="token keyword">for</span> i<span class="token punctuation">,</span> best_n <span class="token keyword">in</span> <span class="token builtin">enumerate</span><span class="token punctuation">(</span>best_ns<span class="token punctuation">)</span><span class="token punctuation">:</span>
            <span class="token keyword">if</span> best_n <span class="token operator">not</span> <span class="token keyword">in</span> anchor_index<span class="token punctuation">:</span>
                <span class="token keyword">continue</span>
            <span class="token comment">#-------------------------------------------------------------#</span>
            <span class="token comment">#   取出各类坐标:</span>
            <span class="token comment">#   gi和gj代表的是真实框对应的特征点的x轴y轴坐标</span>
            <span class="token comment">#   gx和gy代表真实框的x轴和y轴坐标</span>
            <span class="token comment">#   gw和gh代表真实框的宽和高</span>
            <span class="token comment">#-------------------------------------------------------------#</span>
            gi <span class="token operator">=</span> gis<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
            gj <span class="token operator">=</span> gjs<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
            gx <span class="token operator">=</span> gxs<span class="token punctuation">[</span>i<span class="token punctuation">]</span>
            gy <span class="token operator">=</span> gys<span class="token punctuation">[</span>i<span class="token punctuation">]</span>
            gw <span class="token operator">=</span> gws<span class="token punctuation">[</span>i<span class="token punctuation">]</span>
            gh <span class="token operator">=</span> ghs<span class="token punctuation">[</span>i<span class="token punctuation">]</span>

            <span class="token keyword">if</span> <span class="token punctuation">(</span>gj <span class="token operator">&lt;</span> in_h<span class="token punctuation">)</span> <span class="token operator">and</span> <span class="token punctuation">(</span>gi <span class="token operator">&lt;</span> in_w<span class="token punctuation">)</span><span class="token punctuation">:</span>
                best_n <span class="token operator">=</span> best_n <span class="token operator">-</span> subtract_index

                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   noobj_mask代表无目标的特征点</span>
                <span class="token comment">#----------------------------------------#</span>
                noobj_mask<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">0</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   mask代表有目标的特征点</span>
                <span class="token comment">#----------------------------------------#</span>
                mask<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">1</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   tx、ty代表中心调整参数的真实值</span>
                <span class="token comment">#----------------------------------------#</span>
                tx<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> gx <span class="token operator">-</span> gi<span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
                ty<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> gy <span class="token operator">-</span> gj<span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   tw、th代表宽高调整参数的真实值</span>
                <span class="token comment">#----------------------------------------#</span>
                tw<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> math<span class="token punctuation">.</span>log<span class="token punctuation">(</span>gw <span class="token operator">/</span> anchors<span class="token punctuation">[</span>best_n<span class="token operator">+</span>subtract_index<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
                th<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> math<span class="token punctuation">.</span>log<span class="token punctuation">(</span>gh <span class="token operator">/</span> anchors<span class="token punctuation">[</span>best_n<span class="token operator">+</span>subtract_index<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   用于获得xywh的比例</span>
                <span class="token comment">#   大目标loss权重小,小目标loss权重大</span>
                <span class="token comment">#----------------------------------------#</span>
                box_loss_scale_x<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span>i<span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span>
                box_loss_scale_y<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span>i<span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   tconf代表物体置信度</span>
                <span class="token comment">#----------------------------------------#</span>
                tconf<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">1</span>
                <span class="token comment">#----------------------------------------#</span>
                <span class="token comment">#   tcls代表种类置信度</span>
                <span class="token comment">#----------------------------------------#</span>
                tcls<span class="token punctuation">[</span>b<span class="token punctuation">,</span> best_n<span class="token punctuation">,</span> gj<span class="token punctuation">,</span> gi<span class="token punctuation">,</span> <span class="token builtin">int</span><span class="token punctuation">(</span>target<span class="token punctuation">[</span>b<span class="token punctuation">]</span><span class="token punctuation">[</span>i<span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">1</span>
            <span class="token keyword">else</span><span class="token punctuation">:</span>
                <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Step {0} out of bound'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>b<span class="token punctuation">)</span><span class="token punctuation">)</span>
                <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'gj: {0}, height: {1} | gi: {2}, width: {3}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>gj<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span> gi<span class="token punctuation">,</span> in_w<span class="token punctuation">)</span><span class="token punctuation">)</span>
                <span class="token keyword">continue</span>

    <span class="token keyword">return</span> mask<span class="token punctuation">,</span> noobj_mask<span class="token punctuation">,</span> tx<span class="token punctuation">,</span> ty<span class="token punctuation">,</span> tw<span class="token punctuation">,</span> th<span class="token punctuation">,</span> tconf<span class="token punctuation">,</span> tcls<span class="token punctuation">,</span> box_loss_scale_x<span class="token punctuation">,</span> box_loss_scale_y

<span class="token keyword">def</span> <span class="token function">get_ignore</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span>prediction<span class="token punctuation">,</span>target<span class="token punctuation">,</span>scaled_anchors<span class="token punctuation">,</span>in_w<span class="token punctuation">,</span> in_h<span class="token punctuation">,</span>noobj_mask<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token comment">#-----------------------------------------------------#</span>
    <span class="token comment">#   计算一共有多少张图片</span>
    <span class="token comment">#-----------------------------------------------------#</span>
    bs <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>target<span class="token punctuation">)</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    <span class="token comment">#   获得当前特征层先验框所属的编号,方便后面对先验框筛选</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    anchor_index <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">,</span><span class="token punctuation">[</span><span class="token number">3</span><span class="token punctuation">,</span><span class="token number">4</span><span class="token punctuation">,</span><span class="token number">5</span><span class="token punctuation">]</span><span class="token punctuation">,</span><span class="token punctuation">[</span><span class="token number">6</span><span class="token punctuation">,</span><span class="token number">7</span><span class="token punctuation">,</span><span class="token number">8</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>feature_length<span class="token punctuation">.</span>index<span class="token punctuation">(</span>in_w<span class="token punctuation">)</span><span class="token punctuation">]</span>
    scaled_anchors <span class="token operator">=</span> np<span class="token punctuation">.</span>array<span class="token punctuation">(</span>scaled_anchors<span class="token punctuation">)</span><span class="token punctuation">[</span>anchor_index<span class="token punctuation">]</span>

    <span class="token comment"># 先验框的中心位置的调整参数</span>
    x <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>  
    y <span class="token operator">=</span> torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token comment"># 先验框的宽高调整参数</span>
    w <span class="token operator">=</span> prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span>  <span class="token comment"># Width</span>
    h <span class="token operator">=</span> prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span>  <span class="token comment"># Height</span>

    FloatTensor <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>FloatTensor <span class="token keyword">if</span> x<span class="token punctuation">.</span>is_cuda <span class="token keyword">else</span> torch<span class="token punctuation">.</span>FloatTensor
    LongTensor <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>LongTensor <span class="token keyword">if</span> x<span class="token punctuation">.</span>is_cuda <span class="token keyword">else</span> torch<span class="token punctuation">.</span>LongTensor

    <span class="token comment"># 生成网格,先验框中心,网格左上角</span>
    grid_x <span class="token operator">=</span> torch<span class="token punctuation">.</span>linspace<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> in_w <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">,</span> in_w<span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>in_h<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>
        <span class="token builtin">int</span><span class="token punctuation">(</span>bs<span class="token operator">*</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>x<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">type</span><span class="token punctuation">(</span>FloatTensor<span class="token punctuation">)</span>
    grid_y <span class="token operator">=</span> torch<span class="token punctuation">.</span>linspace<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> in_h <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">,</span> in_h<span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>in_w<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>t<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>
        <span class="token builtin">int</span><span class="token punctuation">(</span>bs<span class="token operator">*</span>self<span class="token punctuation">.</span>num_anchors<span class="token operator">/</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>y<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">type</span><span class="token punctuation">(</span>FloatTensor<span class="token punctuation">)</span>

    <span class="token comment"># 生成先验框的宽高</span>
    anchor_w <span class="token operator">=</span> FloatTensor<span class="token punctuation">(</span>scaled_anchors<span class="token punctuation">)</span><span class="token punctuation">.</span>index_select<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> LongTensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    anchor_h <span class="token operator">=</span> FloatTensor<span class="token punctuation">(</span>scaled_anchors<span class="token punctuation">)</span><span class="token punctuation">.</span>index_select<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> LongTensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    
    anchor_w <span class="token operator">=</span> anchor_w<span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> in_h <span class="token operator">*</span> in_w<span class="token punctuation">)</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>w<span class="token punctuation">.</span>shape<span class="token punctuation">)</span>
    anchor_h <span class="token operator">=</span> anchor_h<span class="token punctuation">.</span>repeat<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>repeat<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> in_h <span class="token operator">*</span> in_w<span class="token punctuation">)</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>h<span class="token punctuation">.</span>shape<span class="token punctuation">)</span>
    
    <span class="token comment">#-------------------------------------------------------#</span>
    <span class="token comment">#   计算调整后的先验框中心与宽高</span>
    <span class="token comment">#-------------------------------------------------------#</span>
    pred_boxes <span class="token operator">=</span> FloatTensor<span class="token punctuation">(</span>prediction<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">.</span>shape<span class="token punctuation">)</span>
    pred_boxes<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span> <span class="token operator">=</span> x<span class="token punctuation">.</span>data <span class="token operator">+</span> grid_x
    pred_boxes<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">=</span> y<span class="token punctuation">.</span>data <span class="token operator">+</span> grid_y
    pred_boxes<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span> <span class="token operator">=</span> torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>w<span class="token punctuation">.</span>data<span class="token punctuation">)</span> <span class="token operator">*</span> anchor_w
    pred_boxes<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span> <span class="token operator">=</span> torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>h<span class="token punctuation">.</span>data<span class="token punctuation">)</span> <span class="token operator">*</span> anchor_h

    <span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>bs<span class="token punctuation">)</span><span class="token punctuation">:</span>
        pred_boxes_for_ignore <span class="token operator">=</span> pred_boxes<span class="token punctuation">[</span>i<span class="token punctuation">]</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   将预测结果转换一个形式</span>
        <span class="token comment">#   pred_boxes_for_ignore      num_anchors, 4</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        pred_boxes_for_ignore <span class="token operator">=</span> pred_boxes_for_ignore<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">)</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token comment">#   计算真实框,并把真实框转换成相对于特征层的大小</span>
        <span class="token comment">#   gt_box      num_true_box, 4</span>
        <span class="token comment">#-------------------------------------------------------#</span>
        <span class="token keyword">if</span> <span class="token builtin">len</span><span class="token punctuation">(</span>target<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">&gt;</span> <span class="token number">0</span><span class="token punctuation">:</span>
            gx <span class="token operator">=</span> target<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">:</span><span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_w
            gy <span class="token operator">=</span> target<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_h
            gw <span class="token operator">=</span> target<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">:</span><span class="token number">3</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_w
            gh <span class="token operator">=</span> target<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">:</span><span class="token number">4</span><span class="token punctuation">]</span> <span class="token operator">*</span> in_h
            gt_box <span class="token operator">=</span> torch<span class="token punctuation">.</span>FloatTensor<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>gx<span class="token punctuation">,</span> gy<span class="token punctuation">,</span> gw<span class="token punctuation">,</span> gh<span class="token punctuation">]</span><span class="token punctuation">,</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">type</span><span class="token punctuation">(</span>FloatTensor<span class="token punctuation">)</span>

            <span class="token comment">#-------------------------------------------------------#</span>
            <span class="token comment">#   计算交并比</span>
            <span class="token comment">#   anch_ious       num_true_box, num_anchors</span>
            <span class="token comment">#-------------------------------------------------------#</span>
            anch_ious <span class="token operator">=</span> jaccard<span class="token punctuation">(</span>gt_box<span class="token punctuation">,</span> pred_boxes_for_ignore<span class="token punctuation">)</span>
            <span class="token comment">#-------------------------------------------------------#</span>
            <span class="token comment">#   每个先验框对应真实框的最大重合度</span>
            <span class="token comment">#   anch_ious_max   num_anchors</span>
            <span class="token comment">#-------------------------------------------------------#</span>
            anch_ious_max<span class="token punctuation">,</span> _ <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">max</span><span class="token punctuation">(</span>anch_ious<span class="token punctuation">,</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span>
            anch_ious_max <span class="token operator">=</span> anch_ious_max<span class="token punctuation">.</span>view<span class="token punctuation">(</span>pred_boxes<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
            noobj_mask<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">[</span>anch_ious_max<span class="token operator">&gt;</span>self<span class="token punctuation">.</span>ignore_threshold<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">0</span>
    <span class="token keyword">return</span> noobj_mask

Training your own yolo3 model

The overall folder structure of yolo3 is as follows:
(figure: yolo3 folder structure)
This article uses the VOC format for training.
Before training, place the label files in the Annotation folder under VOCdevkit/VOC2007.
Before training, place the image files in the JPEGImages folder under VOCdevkit/VOC2007.
Before training, run voc2yolo3.py to generate the corresponding txt files.
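
What voc2yolo3.py produces is essentially a set of image-ID lists under VOCdevkit/VOC2007/ImageSets/Main. As a rough illustration only (the split ratio and output file names here are assumptions, not the script's exact behavior), the idea looks like this:

import os
import random

# Illustrative only: split annotation IDs into train/val lists,
# mirroring the kind of txt files voc2yolo3.py generates.
xml_dir = 'VOCdevkit/VOC2007/Annotation'
out_dir = 'VOCdevkit/VOC2007/ImageSets/Main'
os.makedirs(out_dir, exist_ok=True)

ids = [f[:-4] for f in os.listdir(xml_dir) if f.endswith('.xml')]
random.shuffle(ids)

n_train = int(0.9 * len(ids))          # assumed 9:1 train/val split
splits = {'train': ids[:n_train], 'val': ids[n_train:]}

for name, subset in splits.items():
    with open(os.path.join(out_dir, name + '.txt'), 'w') as f:
        f.write('\n'.join(subset))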
Then run voc_annotation.py in the root directory; before running it, change classes to your own classes.

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

 
 

This generates the corresponding 2007_train.txt, where each line records an image's path and the positions of its ground-truth boxes.
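
A line of 2007_train.txt is expected to look roughly like the example below: the image path followed by one group of numbers per object, each group being x_min,y_min,x_max,y_max,class_id (the path and coordinates here are made up for illustration):

VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8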
Before training, you need to modify the voc_classes.txt file in model_data and change the classes to your own. You also need to modify train.py and set num_classes inside it to the number of classes you are detecting.
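
Continuing the hypothetical two-class example above, voc_classes.txt would simply list one class name per line:

cat
dog

and num_classes in train.py would then be set to 2 to match.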

Run train.py to start training.

