Understanding FairMOT and Implementing It with MindSpore

1. Introduction to FairMOT

FairMOT is a multi-object tracking (MOT) baseline proposed by Huazhong University of Science and Technology and Microsoft Research Asia. After analyzing the problems of existing one-shot tracking algorithms, the authors make three observations:
(1) Anchors are not friendly to Re-ID; an anchor-free detector should be used instead.
(2) Features from multiple layers should be fused.
(3) For one-shot methods, lower-dimensional Re-ID feature vectors work better.
FairMOT reaches state-of-the-art results on the MOT15, MOT16, MOT17 and MOT20 benchmarks while running at about 30 fps.
Multi-object tracking has long been a goal of computer vision: the task is to estimate the trajectories of multiple objects in a video. Solving it benefits many applications such as action recognition, sports video analysis, elderly care, and human-computer interaction. Most existing SOTA methods follow a two-step approach:
(1) Detect the objects with a detection algorithm.
(2) Run a Re-ID model to match the detections and link them to existing trajectories according to a metric defined on the features.
Although two-step methods have improved markedly with recent progress in object detection and Re-ID, they do not share feature maps between the detector and the Re-ID model, so they are slow and hard to run at video rate. As two-step methods matured, more researchers turned to one-shot algorithms that detect objects and learn Re-ID features simultaneously; sharing the feature maps greatly reduces inference time, but accuracy drops well below that of two-step methods. The authors therefore analyze one-shot methods and identify the three factors listed above.
Most SOTA trackers are two-step algorithms that split detection and Re-ID into two tasks:
(1) First obtain object locations (bounding boxes) with a detector.
(2) Crop and resize the detected objects, feed them into an identity feature extractor to obtain Re-ID features, and link the boxes into trajectories.
The standard practice for linking boxes into trajectories is to compute a cost matrix from the Re-ID features and the IoU of the boxes, and then use a Kalman filter and the Hungarian algorithm to complete the association (a minimal sketch of this step follows). A small number of works use more sophisticated association strategies such as group models and RNNs.
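As a hedged illustration of this standard association step (not the repository's actual tracker, which uses its own JDETracker later in this tutorial), the sketch below builds a cost matrix from Re-ID cosine distance and box IoU and solves the assignment with SciPy's linear_sum_assignment as a stand-in for the Hungarian algorithm; all names in it are illustrative:

import numpy as np
from scipy.optimize import linear_sum_assignment  # Hungarian-style assignment solver

def iou(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-6)

def associate(track_embs, track_boxes, det_embs, det_boxes, w=0.5):
    """Mix Re-ID cosine distance and (1 - IoU) into one cost matrix and match."""
    track_embs = track_embs / np.linalg.norm(track_embs, axis=1, keepdims=True)
    det_embs = det_embs / np.linalg.norm(det_embs, axis=1, keepdims=True)
    reid_cost = 1.0 - track_embs @ det_embs.T                      # cosine distance
    iou_cost = np.array([[1.0 - iou(t, d) for d in det_boxes] for t in track_boxes])
    cost = w * reid_cost + (1.0 - w) * iou_cost
    rows, cols = linear_sum_assignment(cost)                       # optimal track-detection pairs
    return list(zip(rows, cols))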
The advantage of two-step methods is that the most suitable model can be chosen for each task, and cropping and resizing the predicted boxes before feeding them to the Re-ID feature extractor helps handle changes in object scale. Their tracking quality is good, but they are slow and hard to run at video rate.
The core idea of one-shot methods is to perform object detection and identity embedding (Re-ID feature) extraction in a single network, sharing most of the computation to reduce inference time.
(1) Track-RCNN adds a Re-ID head that regresses a box and a Re-ID feature for every proposal.
(2) JDE is built on the YOLOv3 framework and achieves near video-rate inference.
However, the tracking accuracy of one-shot methods is usually lower than that of two-step trackers. The paper finds that this is because the learned Re-ID features are not optimal, which leads to a large number of ID switches.

2. Details

The FairMOT framework is as follows. The input image is first fed into an encoder-decoder network to extract a high-resolution feature map (stride = 4); two simple parallel heads are then added to predict the bounding boxes and the Re-ID features respectively; finally, the features at the predicted object centers are extracted and used to link the boxes across frames.

(Figure: FairMOT framework)

FairMOT adopts an anchor-free detection method that estimates object centers on the high-resolution feature map. Removing anchors alleviates the ambiguity problem, and the high-resolution feature map helps align the Re-ID features with the object centers.
A parallel branch is added to estimate pixel-wise Re-ID features, which are used to predict object identities. Specifically, FairMOT learns low-dimensional Re-ID features, which reduce computation and improve the robustness of feature matching. In this step, FairMOT improves the ResNet-34 backbone with Deep Layer Aggregation (DLA) so that features from multiple layers are fused and objects of different scales can be handled.

Backbone network

ResNet-34 is adopted as the backbone to strike a good balance between accuracy and speed. To handle objects of different scales, a variant of Deep Layer Aggregation (DLA) is applied to the backbone.
Unlike the original DLA, this variant has more skip connections between the low-level and high-level features, similar to a Feature Pyramid Network (FPN). In addition, all convolution layers in the upsampling modules are replaced by deformable convolutions so that the receptive field adapts dynamically to object scale and pose. These modifications also help alleviate the alignment problem.

Object detection branch

FairMOT treats object detection as a center-based bounding-box regression task on the high-resolution feature map. Specifically, three parallel regression heads are attached to the backbone to estimate the heatmap, the object center offsets, and the bounding-box sizes. Each head is implemented by applying a 3×3 convolution (with 256 channels) to the backbone output feature map, followed by a 1×1 convolution layer that produces the final output (a minimal sketch of these heads follows the list below).
(1) Heatmap head: estimates the locations of object centers. A heatmap-based representation of size 1×H×W is used; the response decays exponentially with the distance between a heatmap location and the object center.
(2) Center offset head: localizes objects more precisely. The alignment between the Re-ID features and the object centers is critical for performance.
(3) Box size head: estimates the height and width of the bounding box at each location. It is not directly related to the Re-ID features, but its localization accuracy affects the evaluation of detection performance.
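A minimal MindSpore sketch of the three detection heads described above (a 3×3 convolution with 256 channels followed by a 1×1 convolution per head). The repository implements this inside its FairMOTMultiHead layer, so the classes below are only an illustrative stand-in; the output channel counts match the hm=1, reg=2, wh=4 settings used in section 5:

from mindspore import nn

class SimpleHead(nn.Cell):
    """3x3 conv (256 channels) + ReLU + 1x1 conv producing `out_channels` maps."""
    def __init__(self, in_channels, out_channels, mid_channels=256):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, mid_channels, kernel_size=3, pad_mode='same', has_bias=True)
        self.relu = nn.ReLU()
        self.out = nn.Conv2d(mid_channels, out_channels, kernel_size=1, has_bias=True)

    def construct(self, x):
        return self.out(self.relu(self.conv(x)))

class DetectionHeads(nn.Cell):
    """Parallel heatmap / center-offset / box-size heads on the stride-4 feature map."""
    def __init__(self, in_channels=64):
        super().__init__()
        self.hm = SimpleHead(in_channels, 1)    # 1 x H x W center heatmap
        self.reg = SimpleHead(in_channels, 2)   # (dx, dy) center offsets
        self.wh = SimpleHead(in_channels, 4)    # box-size terms

    def construct(self, feat):
        return self.hm(feat), self.reg(feat), self.wh(feat)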

Identity embedding branch

The goal of the identity embedding branch is to produce features that can distinguish different objects: ideally, the distance between different objects should be larger than the distance between two observations of the same object. To this end, FairMOT applies a convolution layer with 128 kernels on top of the backbone features to extract a 128-dimensional identity embedding at every location.
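As a hedged sketch of this branch (the helper below and its layout assumptions are mine, not the repository's): a 128-kernel convolution produces a 128-dimensional embedding at every stride-4 location, and the vector at an object center is read out and L2-normalized before matching:

import numpy as np
from mindspore import nn, Tensor

# 128-kernel convolution over the 64-channel backbone output (stride-4 feature map)
id_head = nn.Conv2d(64, 128, kernel_size=3, pad_mode='same', has_bias=True)

def embedding_at_center(id_map, cx, cy, stride=4):
    """Return the L2-normalized 128-d vector at image-space center (cx, cy)."""
    fx, fy = int(cx / stride), int(cy / stride)   # image coords -> feature-map coords
    vec = id_map[0, :, fy, fx].asnumpy()          # id_map shape: (1, 128, H/4, W/4)
    return vec / (np.linalg.norm(vec) + 1e-12)

# toy usage: a 1088x608 input gives a 272x152 feature map at stride 4
feat = Tensor(np.zeros((1, 64, 152, 272), dtype=np.float32))
emb = embedding_at_center(id_head(feat), cx=500.0, cy=300.0)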

Loss functions

(1) Heatmap loss: FairMOT maps each object center onto the heatmap with a Gaussian distribution, and then applies a variant of the focal loss between the predicted and ground-truth heatmaps (a combined sketch of the three losses follows this list):
$$L_{heatmap}=-\frac{1}{N}\sum_{xy}\begin{cases}\left(1-\hat{M}_{xy}\right)^{\alpha}\log\left(\hat{M}_{xy}\right), & \text{if } M_{xy}=1\\ \left(1-M_{xy}\right)^{\beta}\left(\hat{M}_{xy}\right)^{\alpha}\log\left(1-\hat{M}_{xy}\right), & \text{otherwise}\end{cases}$$
where $\hat{M}_{xy}$ is the predicted heatmap response, $M_{xy}$ is the ground-truth heatmap, and $N$ is the total number of objects in the image.
(2) Offset and size loss: FairMOT uses two L1 losses for the offset and size terms:
$$L_{box}=\sum_{i=1}^{N}\left(\left\|o^{i}-\hat{o}^{i}\right\|_{1}+\left\|s^{i}-\hat{s}^{i}\right\|_{1}\right)$$
where $N$ is the total number of objects in the image, $s$ is the box size, and $o$ is the offset of the object center.
(3) Identity embedding loss: the embeddings in FairMOT are also learned through classification, with each object ID treated as a separate class, using a softmax cross-entropy loss:
$$L_{identity}=-\sum_{i=1}^{N}\sum_{k=1}^{K}L^{i}\left(k\right)\log\left(p\left(k\right)\right)$$
where $N$ is the total number of objects in the image and $K$ is the number of classes. In other words, every object in the image is classified by its specific identity: all instances of the same identity are treated as one class.
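A minimal NumPy sketch of the three losses above, using the commonly reported focal-loss exponents α=2 and β=4; shapes and names are illustrative only, since the repository computes these inside its own loss cell:

import numpy as np

def heatmap_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    """Penalty-reduced focal loss between predicted and ground-truth heatmaps."""
    pred = np.clip(pred, eps, 1.0 - eps)
    pos = gt == 1.0                                      # exact object centers
    pos_loss = ((1 - pred) ** alpha * np.log(pred))[pos].sum()
    neg_loss = (((1 - gt) ** beta) * (pred ** alpha) * np.log(1 - pred))[~pos].sum()
    num_obj = max(pos.sum(), 1)
    return -(pos_loss + neg_loss) / num_obj

def box_l1_loss(offset_pred, offset_gt, size_pred, size_gt):
    """L1 losses on center offsets and box sizes, summed over objects."""
    return np.abs(offset_pred - offset_gt).sum() + np.abs(size_pred - size_gt).sum()

def identity_loss(logits, target_ids, eps=1e-12):
    """Softmax cross-entropy over K identity classes for N objects (logits: N x K)."""
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    prob = exp / exp.sum(axis=1, keepdims=True)
    return -np.log(prob[np.arange(len(target_ids)), target_ids] + eps).sum()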

3. Environment Setup

This tutorial is implemented with MindSpore. Before starting the experiment, make sure that mindspore, download, pycocotools, opencv-python, Cython, cython-bbox, decord and the other required Python libraries are installed locally.

Then install the mindvideo package:

git clone https://gitee.com/yanlq46462828/zjut_mindvideo.git
cd zjut_mindvideo

Please first install MindSpore according to the instructions on the official website: https://www.mindspore.cn/install

pip install -r requirements.txt

pip install -e .
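An optional sanity check after installation (mindspore.run_check() is MindSpore's built-in installation check; the other imports only confirm that the dependencies listed above can be loaded):

import mindspore
import cv2, pycocotools, decord, cython_bbox  # dependencies listed above

mindspore.run_check()   # runs a small computation to verify the MindSpore install
print("MindSpore:", mindspore.__version__, "OpenCV:", cv2.__version__)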

4. Data Preparation and Processing

The FairMOT model in this repository is trained and validated on a mixed dataset. The training data used here is the same as in JDE, and we call it "MIX". Please refer to their DATA ZOO to download and prepare all the training data, including Caltech Pedestrian, CityPersons, CUHK-SYSU, PRW, ETHZ, MOT17 and MOT16. Then put all training and evaluation data into one directory and change the "data_root" field in data.json to that directory.


Then, the build_transforms function in ./src/data/builder.py builds the transform pipeline for the videos:
def build_transforms(cfg):
    """Build the data transform pipeline."""
    cfg_pipeline = cfg
    if not isinstance(cfg_pipeline, list):
        return ClassFactory.get_instance_from_cfg(cfg_pipeline,
                                                  ModuleType.PIPELINE)

    transforms = []
    for transform in cfg_pipeline:
        transform_op = build_transforms(transform)
        transforms.append(transform_op)

    return transforms
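If the configuration entry is a list, build_transforms recurses into each item; otherwise a single pipeline operation is instantiated from the class factory. The training script in section 6 calls it directly on the transform configuration:

transforms = build_transforms(config.data_loader.train.map.operations)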

5. Usage

The baseline FairMOT code under the MindSpore framework is as follows:
from typing import Tuple, Union
import numpy as np
from mindspore import nn
from mindspore import ops

from src.utils.check_param import Rel, Validator
from src.utils.class_factory import ClassFactory, ModuleType
from src.models.layers import DeformConv2d, FairMOTMultiHead

class BasicBlock(nn.Cell):
    """Residual block used in DLA."""

    def __init__(self, cin, cout, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv_bn_act = nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride,
                                          pad_mode='pad', padding=dilation, has_bias=False,
                                          dilation=dilation, has_bn=True, momentum=0.9,
                                          activation='relu', after_fake=False)
        self.conv_bn = nn.Conv2dBnAct(cout, cout, kernel_size=3, stride=1, pad_mode='same',
                                      has_bias=False, dilation=dilation, has_bn=True,
                                      momentum=0.9, activation=None)
        self.relu = ops.ReLU()

    def construct(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv_bn_act(x)
        out = self.conv_bn(out)
        out += residual
        out = self.relu(out)
        return out

class Root(nn.Cell):
    """HDA (hierarchical deep aggregation) root node."""

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, has_bias=False,
                              pad_mode='pad', padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = ops.ReLU()
        self.residual = residual
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        children = x
        x = self.conv(self.cat(x))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)
        return x

class Tree(nn.Cell):
    """Build the deep aggregation tree."""

    def __init__(self, levels, block, in_channels, out_channels, stride=1, level_root=False,
                 root_dim=0, root_kernel_size=1, dilation=1, root_residual=False):
        super(Tree, self).__init__()
        self.levels = levels
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if self.levels == 1:
            self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels, stride, root_dim=0,
                              root_kernel_size=root_kernel_size, dilation=dilation,
                              root_residual=root_residual)
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size, dilation=dilation,
                              root_residual=root_residual)
        if self.levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Conv2dBnAct(in_channels, out_channels, kernel_size=1, stride=1,
                                          pad_mode='same', has_bias=False, has_bn=True,
                                          momentum=0.9, activation=None, after_fake=False)

    def construct(self, x, residual=None, children=None):
        children = () if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children += (bottom,)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            ida_node = (x2, x1) + children
            x = self.root(ida_node)
        else:
            children += (x1,)
            x = self.tree2(x1, children=children)
        return x

class DLA34(nn.Cell):
    """Build the DLA-34 downsampling aggregation network."""

    def __init__(self, levels, channels, block=None, residual_root=False):
        super(DLA34, self).__init__()
        self.channels = channels
        self.base_layer = nn.Conv2dBnAct(3, channels[0], kernel_size=7, stride=1, pad_mode='same',
                                         has_bias=False, has_bn=True, momentum=0.9,
                                         activation='relu', after_fake=False)
        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
                           level_root=False, root_residual=residual_root)
        self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
                           level_root=True, root_residual=residual_root)
        self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
                           level_root=True, root_residual=residual_root)
        self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
                           level_root=True, root_residual=residual_root)
        self.dla_fn = [self.level0, self.level1, self.level2, self.level3, self.level4, self.level5]

    def _make_conv_level(self, cin, cout, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.append(nn.Conv2dBnAct(cin, cout, kernel_size=3, stride=stride if i == 0 else 1,
                                          pad_mode='pad', padding=dilation, has_bias=False,
                                          dilation=dilation, has_bn=True, momentum=0.9,
                                          activation='relu', after_fake=False))
            cin = cout
        return nn.SequentialCell(modules)

    def construct(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(len(self.channels)):
            x = self.dla_fn[i](x)
            y.append(x)
        return y

class DlaDeformConv(nn.Cell):
    """Deformable convolution v2 followed by batch norm and ReLU."""

    def __init__(self, cin, cout):
        super(DlaDeformConv, self).__init__()
        self.actf = nn.SequentialCell([
            nn.BatchNorm2d(cout),
            nn.ReLU()])
        self.conv = DeformConv2d(cin, cout, kernel_size=3, stride=1, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        x = self.actf(x)
        return x

class IDAUp(nn.Cell):
    """IDA upsampling."""

    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        proj_list = []
        up_list = []
        node_list = []
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            proj = DlaDeformConv(c, o)
            node = DlaDeformConv(o, o)
            up = nn.Conv2dTranspose(o, o, f * 2, stride=f, pad_mode='pad', padding=f // 2,
                                    group=o)
            proj_list.append(proj)
            up_list.append(up)
            node_list.append(node)
        self.proj = nn.CellList(proj_list)
        self.up = nn.CellList(up_list)
        self.node = nn.CellList(node_list)

    def construct(self, layers, startp, endp):
        for i in range(startp + 1, endp):
            upsample = self.up[i - startp - 1]
            project = self.proj[i - startp - 1]
            layers[i] = upsample(project(layers[i]))
            node = self.node[i - startp - 1]
            layers[i] = node(layers[i] + layers[i - 1])
        return layers

class DLAUp(nn.Cell):
    """DLA upsampling."""

    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        self.startp = startp
        channels = list(channels)
        if in_channels is None:
            in_channels = list(channels)
        scales = np.array(scales, dtype=int)
        self.ida = []
        for i in range(len(channels) - 1):
            j = -i - 2
            self.ida.append(IDAUp(channels[j], in_channels[j:],
                                  scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
        self.ida_nfs = nn.CellList(self.ida)

    def construct(self, layers):
        out = [layers[-1]]  # start with the deepest (stride-32) feature
        for i in range(len(layers) - self.startp - 1):
            ida = self.ida_nfs[i]
            layers = ida(layers, len(layers) - i - 2, len(layers))
            out.append(layers[-1])
        a = []
        i = len(out)
        while i > 0:
            a.append(out[i - 1])
            i -= 1
        return a

@ClassFactory.register(ModuleType.MODEL)
class DLASegConv(nn.Cell):
    """DLA backbone for FairMOT."""

    def __init__(self,
                 down_ratio: int,
                 last_level: int,
                 out_channel: int = 0,
                 stage_levels: Tuple[int] = (1, 1, 1, 2, 2, 1),
                 stage_channels: Tuple[int] = (16, 32, 64, 128, 256, 512)):
        super(DLASegConv, self).__init__()
        Validator.check('down_ratio', down_ratio, 'given_ratio', [2, 4, 8, 16], rel=Rel.IN)
        self.first_level = int(np.log2(down_ratio))
        self.last_level = last_level
        self.base = DLA34(stage_levels, stage_channels, block=BasicBlock)
        channels = stage_channels
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)
        if out_channel == 0:
            out_channel = channels[self.first_level]
        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level],
                            [2 ** i for i in range(self.last_level - self.first_level)])

    def construct(self, image):
        x = self.base(image)
        x = self.dla_up(x)
        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i])
        y = self.ida_up(y, 0, len(y))
        return y[-1]

@ClassFactory.register(ModuleType.MODEL)
class FairmotDla34(nn.Cell):
    """FairMOT network built on the DLA-34 backbone."""

    def __init__(self,
                 down_ratio: int = 4,
                 last_level: int = 5,
                 head_channel: int = 256,
                 head_conv2_ksize: Union[int, Tuple[int]] = 1,
                 hm: int = 1,
                 wh: int = 4,
                 feature_id: int = 128,
                 reg: int = 2):
        super().__init__()
        backbone_output_channel = 64
        self.backbone = DLASegConv(down_ratio=down_ratio,
                                   last_level=last_level)
        self.head = FairMOTMultiHead(heads={'hm': hm, 'wh': wh, 'feature_id': feature_id, 'reg': reg},
                                     in_channel=backbone_output_channel,
                                     head_conv=head_channel,
                                     kernel_size=head_conv2_ksize)

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
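A hedged usage sketch for the model above: instantiate FairmotDla34 with its defaults and push a dummy 1088×608 input through it. The exact structure of the head output depends on FairMOTMultiHead, so only the output type is printed here:

import numpy as np
import mindspore as ms

# build the model with the default stride-4 / DLA-34 configuration
net = FairmotDla34(down_ratio=4, last_level=5)
net.set_train(False)

dummy = ms.Tensor(np.zeros((1, 3, 608, 1088), dtype=np.float32))   # NCHW, one 1088x608 frame
outputs = net(dummy)   # heatmap / wh / feature_id / reg predictions on the stride-4 map
print(type(outputs))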

6. Training

(1) First, prepare the dataset. As described above, the "MIX" dataset is used for training. Then build the data transform pipeline and apply it to the dataset for augmentation. The overall flow is shown in the code below:

# prepare dataset

transforms = build_transforms(config.data_loader.train.map.operations)
data_set = build_dataset(config.data_loader.train.dataset)
data_set.transform = transforms
dataset_train = data_set.run()
Validator.check_int(dataset_train.get_dataset_size(), 0, Rel.GT)
batches_per_epoch = dataset_train.get_dataset_size()

(2) Set up the network. The architecture is as described above, so the model is built directly here, together with the loss, learning rate, optimizer, and other hyperparameters:

# set network

network = build_model(config.model)

# set loss
network_loss = build_loss(config.loss)
# set lr
lr_cfg = config.learning_rate
lr_cfg.steps_per_epoch = int(batches_per_epoch / config.data_loader.group_size)
lr = get_lr(lr_cfg)

# set optimizer
config.optimizer.params = network.trainable_params()
config.optimizer.learning_rate = lr
network_opt = build_optimizer(config.optimizer)

if config.train.pre_trained:
    # load pretrain model
    param_dict = load_checkpoint(config.train.pretrained_model)
    load_param_into_net(network, param_dict)

# set checkpoint for the network
ckpt_config = CheckpointConfig(
    save_checkpoint_steps=config.train.save_checkpoint_steps,
    keep_checkpoint_max=config.train.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix=config.model_name,
                                directory=ckpt_save_dir,
                                config=ckpt_config)

# init the whole Model
model = Model(network,
              network_loss,
              network_opt,
              metrics={"Accuracy": Accuracy()})

(3) Finally, start training. A minimal launch sketch follows, and the training log printed during the first 100 steps of epoch 1 is shown after it.
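The sketch below assumes the epoch count is stored in config.train.epochs (that field name is an assumption; adapt it to the actual configuration file of the repository):

from mindspore.train.callback import LossMonitor, TimeMonitor

# start training with loss logging, timing and checkpointing
model.train(config.train.epochs,          # assumed config field for the number of epochs
            dataset_train,
            callbacks=[LossMonitor(), TimeMonitor(), ckpt_callback],
            dataset_sink_mode=False)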

Loading cuhksysu…
cuhksysu loaded.
Loading caltech…
caltech loaded.
Loading citypersons…
citypersons loaded.
Loading mot17…
mot17 loaded.
Loading prw…
prw loaded.
Loading eth…
eth loaded.
[Start training fairmot_dla34]

epoch: 1 step: 1, loss is 55.63408279418945
epoch: 1 step: 2, loss is 135.11972045898438
epoch: 1 step: 3, loss is 68.24212646484375
epoch: 1 step: 4, loss is 99.36243438720703
epoch: 1 step: 5, loss is 77.19902801513672
epoch: 1 step: 6, loss is 61.66413116455078
epoch: 1 step: 7, loss is 51.42118453979492
epoch: 1 step: 8, loss is 39.20842742919922
epoch: 1 step: 9, loss is 62.66177749633789
epoch: 1 step: 10, loss is 68.33760070800781
epoch: 1 step: 11, loss is 64.51337432861328
epoch: 1 step: 12, loss is 66.97705841064453
epoch: 1 step: 13, loss is 85.49947357177734
epoch: 1 step: 14, loss is 44.306880950927734
epoch: 1 step: 15, loss is 44.892669677734375
epoch: 1 step: 16, loss is 32.04488754272461
epoch: 1 step: 17, loss is 53.46352005004883
epoch: 1 step: 18, loss is 33.23221206665039
epoch: 1 step: 19, loss is 94.40752410888672
epoch: 1 step: 20, loss is 42.513668060302734
epoch: 1 step: 21, loss is 53.83503341674805
epoch: 1 step: 22, loss is 71.90801239013672
epoch: 1 step: 23, loss is 83.15853881835938
epoch: 1 step: 24, loss is 36.20119857788086
epoch: 1 step: 25, loss is 35.76030731201172
epoch: 1 step: 26, loss is 48.240718841552734
epoch: 1 step: 27, loss is 40.638771057128906
epoch: 1 step: 28, loss is 30.805248260498047
epoch: 1 step: 29, loss is 60.74918746948242
epoch: 1 step: 30, loss is 55.86394500732422
epoch: 1 step: 31, loss is 39.79429626464844
epoch: 1 step: 32, loss is 36.09943771362305
epoch: 1 step: 33, loss is 41.27968215942383
epoch: 1 step: 34, loss is 43.07084274291992
epoch: 1 step: 35, loss is 32.99536895751953
epoch: 1 step: 36, loss is 52.2248649597168
epoch: 1 step: 37, loss is 35.28694534301758
epoch: 1 step: 38, loss is 29.907625198364258
epoch: 1 step: 39, loss is 44.55171585083008
epoch: 1 step: 40, loss is 36.937530517578125
epoch: 1 step: 41, loss is 40.78886413574219
epoch: 1 step: 42, loss is 44.26179122924805
epoch: 1 step: 43, loss is 54.04239273071289
epoch: 1 step: 44, loss is 66.3919677734375
epoch: 1 step: 45, loss is 37.05625534057617
epoch: 1 step: 46, loss is 57.69034194946289
epoch: 1 step: 47, loss is 37.09925842285156
epoch: 1 step: 48, loss is 41.87119674682617
epoch: 1 step: 49, loss is 40.871116638183594
epoch: 1 step: 50, loss is 51.75830078125
epoch: 1 step: 51, loss is 40.27484130859375
epoch: 1 step: 52, loss is 32.51845932006836
epoch: 1 step: 53, loss is 65.54149627685547
epoch: 1 step: 54, loss is 54.571102142333984
epoch: 1 step: 55, loss is 48.70039749145508
epoch: 1 step: 56, loss is 40.226768493652344
epoch: 1 step: 57, loss is 40.18015670776367
epoch: 1 step: 58, loss is 50.56803512573242
epoch: 1 step: 59, loss is 45.177005767822266
epoch: 1 step: 60, loss is 52.70391082763672
epoch: 1 step: 61, loss is 44.88543701171875
epoch: 1 step: 62, loss is 33.11354446411133
epoch: 1 step: 63, loss is 37.11306381225586
epoch: 1 step: 64, loss is 38.995479583740234
epoch: 1 step: 65, loss is 47.20582580566406
epoch: 1 step: 66, loss is 33.67197036743164
epoch: 1 step: 67, loss is 30.655174255371094
epoch: 1 step: 68, loss is 38.68879699707031
epoch: 1 step: 69, loss is 64.86235046386719
epoch: 1 step: 70, loss is 64.23455810546875
epoch: 1 step: 71, loss is 28.83365821838379
epoch: 1 step: 72, loss is 36.305667877197266
epoch: 1 step: 73, loss is 32.7441520690918
epoch: 1 step: 74, loss is 28.804264068603516
epoch: 1 step: 75, loss is 27.86435890197754
epoch: 1 step: 76, loss is 41.876983642578125
epoch: 1 step: 77, loss is 31.075077056884766
epoch: 1 step: 78, loss is 33.951351165771484
epoch: 1 step: 79, loss is 27.698165893554688
epoch: 1 step: 80, loss is 26.100616455078125
epoch: 1 step: 81, loss is 42.59402847290039
epoch: 1 step: 82, loss is 27.64974594116211
epoch: 1 step: 83, loss is 33.34096145629883
epoch: 1 step: 84, loss is 48.80719757080078
epoch: 1 step: 85, loss is 36.00349807739258
epoch: 1 step: 86, loss is 49.37395095825195
epoch: 1 step: 87, loss is 50.31093215942383
epoch: 1 step: 88, loss is 38.51315689086914
epoch: 1 step: 89, loss is 30.891132354736328
epoch: 1 step: 90, loss is 30.514766693115234
epoch: 1 step: 91, loss is 47.496952056884766
epoch: 1 step: 92, loss is 38.22492599487305
epoch: 1 step: 93, loss is 30.233394622802734
epoch: 1 step: 94, loss is 36.349761962890625
epoch: 1 step: 95, loss is 37.37440490722656
epoch: 1 step: 96, loss is 30.823909759521484
epoch: 1 step: 97, loss is 51.04092025756836
epoch: 1 step: 98, loss is 30.568363189697266
epoch: 1 step: 99, loss is 49.097557067871094
epoch: 1 step: 100, loss is 53.643253326416016

7. Evaluation

During evaluation, the MOT17 dataset is used. The evaluation results can be found in the "./output" directory.

"""MindSpore Vision Video tracking eval script."""

import os
import numpy as np
from mindspore import context, load_checkpoint, load_param_into_net

from msvideo.utils.config import parse_args, Config
from msvideo.models import build_model
from msvideo.utils.post_process.infer_net import InferNet, WithInferNetCell
from msvideo.utils.post_process.eval_utils.eval_seq import eval_seq
from msvideo.utils.post_process.eval_utils.load_images import LoadImages
from msvideo.utils.post_process.tracker.multitracker import JDETracker
from msvideo.utils.post_process.tracking_utils.evaluation import Evaluator

import motmetrics as mm

def eval_tracking(pargs):
    # set config context
    config = Config(pargs.config)
    context.set_context(**config.context)

    # set network
    network = build_model(config.model)

    # load pretrain model
    param_dict = load_checkpoint(config.eval.ckpt_path)
    load_param_into_net(network, param_dict)

    # init the whole Model
    infer_net = InferNet()
    net = WithInferNetCell(network, infer_net)
    net.set_train(False)

# Calculate eval results.
accs = []
timer_avgs, timer_calls = [], []
if config.eval.data_seqs == "mot16":
    eval_seqs = ['MOT16-02',
                 'MOT16-04',
                 'MOT16-05',
                 'MOT16-09',
                 'MOT16-10',
                 'MOT16-11',
                 'MOT16-13']
elif config.eval.data_seqs == "mot17":
    eval_seqs = ['MOT17-02-SDP',
                 'MOT17-04-SDP',
                 'MOT17-05-SDP',
                 'MOT17-09-SDP',
                 'MOT17-10-SDP',
                 'MOT17-13-SDP']
elif config.eval.data_seqs == "mot20":
    eval_seqs = ['MOT20-01',
                 'MOT20-02',
                 'MOT20-03',
                 'MOT20-05']
if not os.path.exists(config.eval.output_dir):
    os.mkdir(config.eval.output_dir)
for seq in eval_seqs:
    dataloader = LoadImages(os.path.join(config.eval.data_root, seq, 'img1'),
                            0,
                            (1088, 608))
    tracker = JDETracker(config.eval.conf_thres,
                         config.eval.track_buffer,
                         config.eval.max_objs,
                         config.eval.num_classes,
                         frame_rate=30)
    seq_save_dir = os.path.join(config.eval.output_dir, seq)
    result_filename = os.path.join(seq_save_dir, 'eval_result.txt')
    _, ta, tc = eval_seq(net,
                         dataloader,
                         tracker,
                         config.eval.down_ratio,
                         config.eval.min_box_area,
                         config.eval.data_type,
                         result_filename,
                         start_id=0,
                         save_dir=seq_save_dir,
                         show_image=False)
    evaluator = Evaluator(config.eval.data_root, seq, config.eval.data_type)
    accs.append(evaluator.eval_file(result_filename))
timer_avgs = np.asarray(ta)
timer_calls = np.asarray(tc)
all_time = np.dot(timer_avgs, timer_calls)
avg_time = all_time / np.sum(timer_calls)
print('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))

metrics = mm.metrics.motchallenge_metrics
mh = mm.metrics.create()
summary = Evaluator.get_summary(accs, eval_seqs, metrics)
strsummary = mm.io.render_summary(
    summary,
    formatters=mh.formatters,
    namemap=mm.io.motchallenge_metric_names
)
print(strsummary)

if __name__ == '__main__':
    args = parse_args()
    eval_tracking(args)
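Assuming parse_args reads the configuration file from a --config option (both the script name and the flag here are placeholders; check msvideo/utils/config.py and the repository's example configs), the evaluation can be launched with a command like python eval.py --config <your_fairmot_eval_config>.yaml, which produces the per-sequence logs shown below.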

Processing frame 0 (100000.00 fps)
Processing frame 20 (1.23 fps)
Processing frame 40 (2.11 fps)
Processing frame 60 (2.78 fps)
Processing frame 80 (3.32 fps)
Processing frame 100 (3.76 fps)
Processing frame 120 (4.10 fps)
Processing frame 140 (4.40 fps)
Processing frame 160 (4.66 fps)
Processing frame 180 (4.88 fps)
Processing frame 200 (5.06 fps)
Processing frame 220 (5.23 fps)
Processing frame 240 (5.38 fps)
Processing frame 260 (5.52 fps)
Processing frame 280 (5.63 fps)
Processing frame 300 (5.72 fps)
Processing frame 320 (5.81 fps)
Processing frame 340 (5.90 fps)
Processing frame 360 (5.98 fps)
Processing frame 380 (6.05 fps)
Processing frame 400 (6.12 fps)
Processing frame 420 (6.17 fps)
Processing frame 440 (6.22 fps)
Processing frame 460 (6.28 fps)
Processing frame 480 (6.32 fps)
Processing frame 500 (6.36 fps)
Processing frame 520 (6.41 fps)
Processing frame 540 (6.45 fps)
Processing frame 560 (6.49 fps)
Processing frame 580 (6.54 fps)
save results to %s output/MOT17-02-SDP/eval_result.txt
Processing frame 0 (100000.00 fps)
Processing frame 20 (7.58 fps)
Processing frame 40 (7.59 fps)
Processing frame 60 (7.57 fps)
Processing frame 80 (7.63 fps)
Processing frame 100 (7.64 fps)
Processing frame 120 (7.59 fps)
Processing frame 140 (7.59 fps)
Processing frame 160 (7.59 fps)
Processing frame 180 (7.59 fps)
Processing frame 200 (7.56 fps)
Processing frame 220 (7.54 fps)
Processing frame 240 (7.52 fps)
Processing frame 260 (7.52 fps)
Processing frame 280 (7.49 fps)
Processing frame 300 (7.46 fps)
Processing frame 320 (7.47 fps)
Processing frame 340 (7.47 fps)
Processing frame 360 (7.44 fps)
Processing frame 380 (7.44 fps)
Processing frame 400 (7.45 fps)
Processing frame 420 (7.43 fps)
Processing frame 440 (7.41 fps)
Processing frame 460 (7.39 fps)
Processing frame 480 (7.37 fps)
Processing frame 500 (7.37 fps)
Processing frame 520 (7.37 fps)
Processing frame 540 (7.36 fps)
Processing frame 560 (7.36 fps)
Processing frame 580 (7.36 fps)
Processing frame 600 (7.36 fps)
Processing frame 620 (7.36 fps)
Processing frame 640 (7.37 fps)
Processing frame 660 (7.37 fps)
Processing frame 680 (7.37 fps)
Processing frame 700 (7.38 fps)
Processing frame 720 (7.38 fps)
Processing frame 740 (7.37 fps)
Processing frame 760 (7.37 fps)
Processing frame 780 (7.38 fps)
Processing frame 800 (7.38 fps)
Processing frame 820 (7.38 fps)
Processing frame 840 (7.38 fps)
Processing frame 860 (7.38 fps)
Processing frame 880 (7.37 fps)
Processing frame 900 (7.36 fps)
Processing frame 920 (7.36 fps)
Processing frame 940 (7.36 fps)
Processing frame 960 (7.36 fps)
Processing frame 980 (7.35 fps)
Processing frame 1000 (7.31 fps)
Processing frame 1020 (7.30 fps)
Processing frame 1040 (7.30 fps)
save results to %s output/MOT17-04-SDP/eval_result.txt
Processing frame 0 (100000.00 fps)
Processing frame 20 (8.27 fps)
Processing frame 40 (8.21 fps)
Processing frame 60 (8.22 fps)
Processing frame 80 (8.23 fps)
Processing frame 100 (8.23 fps)
Processing frame 120 (8.22 fps)
Processing frame 140 (8.20 fps)
Processing frame 160 (8.20 fps)
Processing frame 180 (8.22 fps)
Processing frame 200 (8.21 fps)
Processing frame 220 (8.21 fps)
Processing frame 240 (8.22 fps)
Processing frame 260 (8.20 fps)
Processing frame 280 (8.19 fps)
Processing frame 300 (8.20 fps)
Processing frame 320 (8.20 fps)
Processing frame 340 (8.19 fps)
Processing frame 360 (8.19 fps)
Processing frame 380 (8.19 fps)
Processing frame 400 (8.17 fps)
Processing frame 420 (8.18 fps)
Processing frame 440 (8.18 fps)
Processing frame 460 (8.18 fps)
Processing frame 480 (8.18 fps)
Processing frame 500 (8.19 fps)
Processing frame 520 (8.19 fps)
Processing frame 540 (8.19 fps)
Processing frame 560 (8.19 fps)
Processing frame 580 (8.20 fps)
Processing frame 600 (8.20 fps)
Processing frame 620 (8.20 fps)
Processing frame 640 (8.20 fps)
Processing frame 660 (8.20 fps)
Processing frame 680 (8.21 fps)
Processing frame 700 (8.21 fps)
Processing frame 720 (8.21 fps)
Processing frame 740 (8.21 fps)
Processing frame 760 (8.21 fps)
Processing frame 780 (8.21 fps)
Processing frame 800 (8.20 fps)
Processing frame 820 (8.20 fps)
save results to %s output/MOT17-05-SDP/eval_result.txt
Processing frame 0 (100000.00 fps)
Processing frame 20 (8.16 fps)
Processing frame 40 (8.11 fps)
Processing frame 60 (8.06 fps)
Processing frame 80 (8.04 fps)
Processing frame 100 (8.07 fps)
Processing frame 120 (8.06 fps)
Processing frame 140 (8.06 fps)
Processing frame 160 (8.07 fps)
Processing frame 180 (8.07 fps)
Processing frame 200 (8.07 fps)
Processing frame 220 (8.08 fps)
Processing frame 240 (8.08 fps)
Processing frame 260 (8.07 fps)
Processing frame 280 (8.07 fps)
Processing frame 300 (8.07 fps)
Processing frame 320 (8.07 fps)
Processing frame 340 (8.07 fps)
Processing frame 360 (8.08 fps)
Processing frame 380 (8.08 fps)
Processing frame 400 (8.07 fps)
Processing frame 420 (8.07 fps)
Processing frame 440 (8.07 fps)
Processing frame 460 (8.07 fps)
Processing frame 480 (8.07 fps)
Processing frame 500 (8.07 fps)
Processing frame 520 (8.07 fps)
save results to %s output/MOT17-09-SDP/eval_result.txt
Processing frame 0 (100000.00 fps)
Processing frame 20 (7.91 fps)
Processing frame 40 (7.76 fps)
Processing frame 60 (7.64 fps)
Processing frame 80 (7.67 fps)
Processing frame 100 (7.67 fps)
Processing frame 120 (7.71 fps)
Processing frame 140 (7.72 fps)
Processing frame 160 (7.75 fps)
Processing frame 180 (7.77 fps)
Processing frame 200 (7.73 fps)
Processing frame 220 (7.74 fps)
Processing frame 240 (7.76 fps)
Processing frame 260 (7.77 fps)
Processing frame 280 (7.74 fps)
Processing frame 300 (7.73 fps)
Processing frame 320 (7.69 fps)
Processing frame 340 (7.68 fps)
Processing frame 360 (7.68 fps)
Processing frame 380 (7.68 fps)
Processing frame 400 (7.69 fps)
Processing frame 420 (7.69 fps)
Processing frame 440 (7.70 fps)
Processing frame 460 (7.71 fps)
Processing frame 480 (7.72 fps)
Processing frame 500 (7.73 fps)
Processing frame 520 (7.74 fps)
Processing frame 540 (7.75 fps)
Processing frame 560 (7.75 fps)
Processing frame 580 (7.76 fps)
Processing frame 600 (7.76 fps)
Processing frame 620 (7.77 fps)
Processing frame 640 (7.77 fps)
save results to %s output/MOT17-10-SDP/eval_result.txt
Processing frame 0 (100000.00 fps)
Processing frame 20 (8.05 fps)
Processing frame 40 (7.99 fps)
Processing frame 60 (7.91 fps)
Processing frame 80 (7.93 fps)
Processing frame 100 (7.91 fps)
Processing frame 120 (7.86 fps)
Processing frame 140 (7.84 fps)
Processing frame 160 (7.84 fps)
Processing frame 180 (7.83 fps)
Processing frame 200 (7.84 fps)
Processing frame 220 (7.87 fps)
Processing frame 240 (7.87 fps)
Processing frame 260 (7.89 fps)
Processing frame 280 (7.90 fps)
Processing frame 300 (7.90 fps)
Processing frame 320 (7.91 fps)
Processing frame 340 (7.92 fps)
Processing frame 360 (7.92 fps)
Processing frame 380 (7.92 fps)
Processing frame 400 (7.93 fps)
Processing frame 420 (7.94 fps)
Processing frame 440 (7.95 fps)
Processing frame 460 (7.96 fps)
Processing frame 480 (7.98 fps)
Processing frame 500 (7.99 fps)
Processing frame 520 (8.00 fps)
Processing frame 540 (8.01 fps)
Processing frame 560 (8.01 fps)
Processing frame 580 (8.02 fps)
Processing frame 600 (8.04 fps)
Processing frame 620 (8.05 fps)
Processing frame 640 (8.06 fps)
Processing frame 660 (8.07 fps)
Processing frame 680 (8.08 fps)
Processing frame 700 (8.09 fps)
Processing frame 720 (8.10 fps)
Processing frame 740 (8.10 fps)
save results to %s output/MOT17-13-SDP/eval_result.txt
Time elapsed: 92.48 seconds, FPS: 8.11


8. References

Paper: https://arxiv.org/pdf/2004.01888v2.pdf

Blogs:
https://blog.csdn.net/weixin_42398658/article/details/110873083
https://blog.csdn.net/qq_41204464/article/details/122893061

Code repositories: Yanlq/zjut_mindvideo
https://github.com/ZJUT-ERCISS/fairmot_mindspore
tutorials/tracking/fairmot/fairmot_example.ipynb · Yanlq/zjut_mindvideo - Gitee.com
