Single Object Tracking with SiamMask: Tracking a Specific Target Vehicle, Part 2


Single Object Tracking with SiamMask: Tracking a Specific Target Vehicle, Part 1

Single Object Tracking with SiamMask: Tracking a Specific Target Vehicle, Part 2

Single Object Tracking with the Siamese Network Family: SiamFC, SiamRPN, one-shot tracking, one-shot (single-sample) learning, DaSiamRPN, SiamRPN++, SiamMask

Single Object Tracking: Tracking Results

Single Object Tracking: Dataset Processing

Single Object Tracking: Model Construction

Single Object Tracking: Model Training

Single Object Tracking: Model Testing


SiamMask_master\tools

config.json

{
    "network": {
        "arch": "Custom"
    },
    "hp": {
        "instance_size": 255,
        "base_size": 8,
        "out_size": 127,
        "seg_thr": 0.35,
        "penalty_k": 0.04,
        "window_influence": 0.4,
        "lr": 1.0
    },
    "anchors": {
        "stride": 8,
        "ratios": [0.33, 0.5, 1, 2, 3],
        "scales": [8],
        "round_dight": 0
    }
}
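
The hp block drives the tracker at test time: instance_size is the search-image side, out_size the mask side, seg_thr the mask binarization threshold, and penalty_k, window_influence and lr the weights of the scale penalty, cosine window and size update used by siamese_track in test.py below ("round_dight" is the key's spelling expected by the original repo). A minimal sketch of the derived quantities, assuming the score-size formula from the repo's TrackerConfig.renew() and its default exemplar_size of 127:

import json

with open('config.json') as f:
    cfg = json.load(f)

hp, anchors = cfg['hp'], cfg['anchors']
exemplar_size = 127  # TrackerConfig default; not stored in config.json
# score-map side, as computed by TrackerConfig.renew() in the original repo
score_size = (hp['instance_size'] - exemplar_size) // anchors['stride'] + 1 + hp['base_size']
# anchors per score-map position: one per (ratio, scale) pair
anchor_num = len(anchors['ratios']) * len(anchors['scales'])
print(score_size, anchor_num)  # 25, 5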

demo.py

# --------------------------------------------------------
# demo.py
# --------------------------------------------------------
import glob
import torch
# the star import from tools.test provides argparse, cv2, np, join, isfile,
# load_config, load_pretrain, siamese_init and siamese_track
from tools.test import *
# 1. create the argument parser
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')

# 2. add the arguments
# 2.1 resume: path to the trained checkpoint to load
parser.add_argument('--resume', default='SiamMask.pth', type=str,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
# 2.2 config: tracker configuration file
parser.add_argument('--config', dest='config', default='config.json',
                    help='hyper-parameter of SiamMask in json format')
# 2.3 base_path: directory of the image sequence to process
parser.add_argument('--base_path', default='../data/car', help='datasets')
# 2.4 hardware information
parser.add_argument('--cpu', action='store_true', help='cpu mode')
# 3. parse the arguments
args = parser.parse_args()

writer = None

if __name__ == '__main__':
    # 1. Setup device
    # use the GPU when one is available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # let cuDNN benchmark and pick the fastest convolution algorithms
    torch.backends.cudnn.benchmark = True

    # 2. Setup Model
    # 2.1 parse the config file named on the command line
    cfg = load_config(args)

    # 2.2 Custom is the network built for SiamMask; other architectures would come from models
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    # 2.3 check that the weight file exists, then load it
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)
    # model.eval() must be called before inference so that dropout and batch
    # normalization layers run in evaluation (non-training) mode;
    # to(device) moves the parameters to the GPU so later computation runs there
    siammask.eval().to(device)

    # 3. Parse Image files
    img_files = sorted(glob.glob(join(args.base_path, '*.jp*')))
    ims = [cv2.imread(imf) for imf in img_files]

    # 4. Select the target region (ROI)
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    # cv2.setWindowProperty("SiamMask", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    # 5. selectROI returns the box as top-left x, top-left y, width, height
    try:
        init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
        x, y, w, h = init_rect
        print(x, y, w, h)
    except Exception:
        exit()

    toc = 0
    # 6. iterate over all frames
    for f, im in enumerate(ims):
        # timing: tick count at the start of this frame
        tic = cv2.getTickCount()
        # initialization on the first frame
        if f == 0:  # init
            # target center position
            target_pos = np.array([x + w / 2, y + h / 2])
            # target size
            target_sz = np.array([w, h])
            # initialize the tracker state
            state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'], device=device)  # init tracker
        # tracking on the remaining frames
        elif f > 0:  # tracking
            # track the target and update the state
            state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device)  # track
            # target location ('ploygon' is the key's spelling in the original repo)
            location = state['ploygon'].flatten()
            # binarize the predicted segmentation mask
            mask = state['mask'] > state['p'].seg_thr
            # overlay the mask on the image (red channel)
            im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
            # draw the rotated bounding box of the tracked target
            cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
            cv2.imshow('SiamMask', im)
            key = cv2.waitKey(1)
            if key > 0:
                break
        # timing: accumulate the tick count spent on this frame
        toc += cv2.getTickCount() - tic
    # total processing time over all frames, in seconds
    toc /= cv2.getTickFrequency()
    # frames per second
    fps = f / toc
    print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visualization!)'.format(toc, fps))
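
demo.py reads the whole image sequence into memory before tracking starts. For long sequences or camera input, the same siamese_init/siamese_track API can be fed one frame at a time; the sketch below is such an adaptation (not part of the original repo), with the input file name video.mp4 chosen purely for illustration:

# sketch: stream frames from a video file instead of preloading JPEGs
cap = cv2.VideoCapture('video.mp4')  # hypothetical input file
ok, frame = cap.read()
if ok:
    x, y, w, h = cv2.selectROI('SiamMask', frame, False, False)
    target_pos = np.array([x + w / 2, y + h / 2])
    target_sz = np.array([w, h])
    state = siamese_init(frame, target_pos, target_sz, siammask, cfg['hp'], device=device)
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        state = siamese_track(state, frame, mask_enable=True, refine_enable=True, device=device)
        cv2.polylines(frame, [np.int0(state['ploygon'].flatten()).reshape((-1, 1, 2))],
                      True, (0, 255, 0), 3)
        cv2.imshow('SiamMask', frame)
        if cv2.waitKey(1) > 0:
            break
cap.release()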

resnet.py

import torch.nn as nn
import torch
from torch.autograd import Variable
import math
import torch.utils.model_zoo as model_zoo
from models.features import Features

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

# URLs of ImageNet-pretrained ResNet checkpoints
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """
    "3x3 convolution with padding"
    3*3卷积
    :param in_planes: 输入通道数
    :param out_planes: 输出通道数
    :param stride: 步长
    :return:
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """
    基础的瓶颈模块,由两个叠加的3*3卷积组成,用于res18和res34
    """
    # 对输出深度的倍乘
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # 3x3 convolution, BN layer, ReLU activation
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # shortcut
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # input from the previous layer (identity branch)
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        # when a shortcut projection exists
        if self.downsample is not None:
            # pass the block input x through downsample and use the result as the
            # residual, so its channel depth matches the block output
            residual = self.downsample(x)
        # add the shortcut to the BN output
        out += residual
        # ReLU activation
        out = self.relu(out)
        return out


class Bottleneck(Features):
    """
    瓶颈模块,有1*1 3*3 1*1三个卷积层构成,分别用来降低维度,卷积处理和升高维度
    继承在feature,用于特征提取
    """
    # 将输入深度进行倍乘(若输入深度为64,那么扩张4倍后就变为了256)
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        # 1x1 convolution
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        # BN
        self.bn1 = nn.BatchNorm2d(planes)
        # padding = (2 - stride) + (dilation // 2 - 1)
        padding = 2 - stride
        assert stride == 1 or dilation == 1, "at least one of stride and dilation must be 1"
        if dilation > 1:
            padding = dilation
        # 3x3 convolution
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 convolution
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # shortcut
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if out.size() != residual.size():
            print(out.size(), residual.size())
        out += residual

        out = self.relu(out)

        return out



class Bottleneck_nop(nn.Module):
    """
    官网原始的瓶颈块
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_nop, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        s = residual.size(3)
        residual = residual[:, :, 1:s-1, 1:s-1]

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """
    ResNet主体部分实现
    """
    def __init__(self, block, layers, layer4=False, layer3=False):
        """
        主体实现
        :param block: 基础块:BasicBlock或者BottleNeck
        :param layers: 每个大的layer中的block个数
        :param layer4:
        :param layer3:是否加入layer3和layer4
        """
        # input depth
        self.inplanes = 64
        super(ResNet, self).__init__()
        # stem convolution
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,
                               bias=False)
        # BN
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # pooling
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # stack the residual blocks into layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2) # 31x31, 15x15

        self.feature_size = 128 * block.expansion
        # optionally add layer3
        if layer3:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) # 15x15, 7x7
            self.feature_size = (256 + 128) * block.expansion
        else:
            self.layer3 = lambda x: x  # identity
        # optionally add layer4
        if layer4:
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) # 7x7, 3x3
            self.feature_size = 512 * block.expansion
        else:
            self.layer4 = lambda x: x  # identity
        # parameter initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """
        :param block: BasicBlock or Bottleneck
        :param planes: number of channels
        :param blocks: number of blocks to add
        :param stride: stride
        :param dilation: dilation
        :return: the assembled layer
        """
        downsample = None
        dd = dilation
        # configure the shortcut projection
        if stride != 1 or self.inplanes != planes * block.expansion:
            if stride == 1 and dilation == 1:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                if dilation > 1:
                    dd = dilation // 2
                    padding = dd
                else:
                    dd = 1
                    padding = 0
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=3, stride=stride, bias=False,
                              padding=padding, dilation=dd),
                    nn.BatchNorm2d(planes * block.expansion),
                )

        layers = []
        # layers.append(block(self.inplanes, planes, stride, downsample, dilation=dilation))
        layers.append(block(self.inplanes, planes, stride, downsample, dilation=dd))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        # assemble the layer
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        p0 = self.relu(x)
        x = self.maxpool(p0)

        p1 = self.layer1(x)
        p2 = self.layer2(p1)
        p3 = self.layer3(p2)

        return p0, p1, p2, p3


class ResAdjust(nn.Module):
    """
    对模块进行adjust,未使用
    """

    def __init__(self,
            block=Bottleneck,
            out_channels=256,
            adjust_number=1,
            fuse_layers=[2,3,4]):
        super(ResAdjust, self).__init__()
        self.fuse_layers = set(fuse_layers)

        if 2 in self.fuse_layers:
            self.layer2 = self._make_layer(block, 128, 1, out_channels, adjust_number)
        if 3 in self.fuse_layers:
            self.layer3 = self._make_layer(block, 256, 2, out_channels, adjust_number)
        if 4 in self.fuse_layers:
            self.layer4 = self._make_layer(block, 512, 4, out_channels, adjust_number)

        self.feature_size = out_channels * len(self.fuse_layers)


    def _make_layer(self, block, plances, dilation, out, number=1):

        layers = []

        for _ in range(number):
            layer = block(plances * block.expansion, plances, dilation=dilation)
            layers.append(layer)

        downsample = nn.Sequential(
                nn.Conv2d(plances * block.expansion, out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out)
                )
        layers.append(downsample)

        return nn.Sequential(*layers)

    def forward(self, p2, p3, p4):

        outputs = []

        if 2 in self.fuse_layers:
            outputs.append(self.layer2(p2))
        if 3 in self.fuse_layers:
            outputs.append(self.layer3(p3))
        if 4 in self.fuse_layers:
            outputs.append(self.layer4(p4))
        # return torch.cat(outputs, 1)
        return outputs


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    重构resnet50
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

if __name__ == '__main__':
    net = resnet50()
    print(net)
    # net = net.cuda()
    #
    # var = torch.FloatTensor(1,3,127,127).cuda()
    # var = Variable(var)
    #
    # net(var)
    # print('*************')
    # var = torch.FloatTensor(1,3,255,255).cuda()
    # var = Variable(var)

    # net(var)

    # smoke test: random data instead of uninitialized memory
    var = torch.randn(1, 3, 127, 127)
    var = Variable(var)
    out = net(var)
    print(out)
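
The spatial sizes noted in the layer comments (31x31 after layer1, 15x15 after the stride-2 layer2, 15x15 preserved by the dilated layer3) can be checked with a quick forward pass. A minimal sketch, run from the same file or after importing it:

import torch

# SiamMask-style backbone: dilated layer3, no layer4
net = resnet50(layer3=True, layer4=False)
net.eval()
with torch.no_grad():
    p0, p1, p2, p3 = net(torch.randn(1, 3, 127, 127))
# expected shapes for a 127x127 exemplar (per the layer comments above):
# p0: (1, 64, 61, 61)    stem 7x7 stride-2 conv, no padding
# p1: (1, 256, 31, 31)   layer1 after the 3x3 stride-2 max-pool
# p2: (1, 512, 15, 15)   layer2, stride 2
# p3: (1, 1024, 15, 15)  dilated layer3 keeps the 15x15 resolution
for p in (p0, p1, p2, p3):
    print(tuple(p.shape))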

test.py

# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
from __future__ import division
import argparse
import logging
import numpy as np
import cv2
from PIL import Image
from os import makedirs
from os.path import join, isdir, isfile

from utils.log_helper import init_log, add_file_handler
from utils.load_helper import load_pretrain
from utils.bbox_helper import get_axis_aligned_bbox, cxy_wh_2_rect
from utils.benchmark_helper import load_dataset, dataset_zoo

import torch
from torch.autograd import Variable
import torch.nn.functional as F

from utils.anchors import Anchors
from utils.tracker_config import TrackerConfig

from utils.config_helper import load_config
# from utils.pyvotkit.region import vot_overlap, vot_float2str
# NOTE: vot_overlap / vot_float2str are used by the VOT branches below; with this
# import commented out, VOT evaluation requires compiling pyvotkit first
# thresholds above which a pixel counts as target in the segmentation evaluation
thrs = np.arange(0.3, 0.5, 0.05)
# command-line argument configuration
parser = argparse.ArgumentParser(description='Test SiamMask')
parser.add_argument('--arch', dest='arch', default='', choices=['Custom',],
                    help='architecture of pretrained model')
parser.add_argument('--config', dest='config', required=True, help='hyper-parameter for SiamMask')
parser.add_argument('--resume', default='', type=str, required=True,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--mask', action='store_true', help='whether use mask output')
parser.add_argument('--refine', action='store_true', help='whether use mask refine output')
parser.add_argument('--dataset', dest='dataset', default='VOT2018', choices=dataset_zoo,
                    help='datasets')
parser.add_argument('-l', '--log', default="log_test.txt", type=str, help='log file')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether visualize result')
parser.add_argument('--save_mask', action='store_true', help='whether use save mask for davis')
parser.add_argument('--gt', action='store_true', help='whether use gt rect for davis (Oracle)')
parser.add_argument('--video', default='', type=str, help='test special video')
parser.add_argument('--cpu', action='store_true', help='cpu mode')
parser.add_argument('--debug', action='store_true', help='debug mode')


def to_torch(ndarray):
    '''
    Convert data to a torch tensor
    :param ndarray: ndarray
    :return: torch tensor
    '''
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray


def im_to_torch(img):
    '''
    Convert an image to a torch tensor
    :param img: input image
    :return: output tensor (C*H*W)
    '''
    img = np.transpose(img, (2, 0, 1))  # C*H*W
    img = to_torch(img).float()
    return img


def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch'):
    """
    获取跟踪目标的信息(图像窗口)
    :param im:跟踪的模板图像
    :param pos:目标位置
    :param model_sz:模型要求输入的目标尺寸
    :param original_sz: 扩展后的目标尺寸
    :param avg_chans:图像的平均值
    :param out_mode: 输出模式
    :return:
    """
    if isinstance(pos, float):
        # 目标中心点坐标
        pos = [pos, pos]
    # 目标的尺寸
    sz = original_sz
    # 图像尺寸
    im_sz = im.shape
    # 扩展背景后边界到中心的距离
    c = (original_sz + 1) / 2
    # 判断目标是否超出图像边界,若超出边界则对图像进行填充
    context_xmin = round(pos[0] - c)
    context_xmax = context_xmin + sz - 1
    context_ymin = round(pos[1] - c)
    context_ymax = context_ymin + sz - 1
    left_pad = int(max(0., -context_xmin))
    top_pad = int(max(0., -context_ymin))
    right_pad = int(max(0., context_xmax - im_sz[1] + 1))
    bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
    # 图像填充使得图像的原点发生变化,计算填充后图像块的坐标
    context_xmin = context_xmin + left_pad
    context_xmax = context_xmax + left_pad
    context_ymin = context_ymin + top_pad
    context_ymax = context_ymax + top_pad

    # zzp: a more easy speed version
    r, c, k = im.shape
    # when padding is needed, build the padded image first
    if any([top_pad, bottom_pad, left_pad, right_pad]):
        # allocate an all-zero array the size of the padded image
        te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8)
        # copy the original image into the center region
        te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
        # fill the padded regions with the per-channel image mean
        if top_pad:
            te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
        if bottom_pad:
            te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
        if left_pad:
            te_im[:, 0:left_pad, :] = avg_chans
        if right_pad:
            te_im[:, c + left_pad:, :] = avg_chans
        # crop the patch from the padded image
        im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
    else:
        im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
    # if the patch size differs from the model input size, resize it with OpenCV
    if not np.array_equal(model_sz, original_sz):
        im_patch = cv2.resize(im_patch_original, (model_sz, model_sz))
    else:
        im_patch = im_patch_original
    # cv2.imshow('crop', im_patch)
    # cv2.waitKey(0)
    # in 'torch' mode transpose the channels to C*H*W, otherwise return im_patch directly
    return im_to_torch(im_patch) if out_mode in 'torch' else im_patch
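
# Worked example (not part of the original file): for a target centered at
# (30, 30) with original_sz = 101, c = (101 + 1) / 2 = 51 and
# context_xmin = round(30 - 51) = -21, so left_pad = 21; after shifting by the
# padding, the patch spans columns 0..100 of the mean-padded image and is then
# resized to model_sz x model_sz.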


def generate_anchor(cfg, score_size):
    """
    生成锚点:anchor
    :param cfg: anchor的配置信息
    :param score_size:分类的评分结果
    :return:生成的anchor
    """
    # 初始化anchor
    anchors = Anchors(cfg)
    # 得到生成的anchors
    anchor = anchors.anchors
    # 得到每一个anchor的左上角和右下角坐标
    x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3]
    # 将anchor转换为中心点坐标和宽高的形式
    anchor = np.stack([(x1+x2)*0.5, (y1+y2)*0.5, x2-x1, y2-y1], 1)
    # 获取生成anchor的范围
    total_stride = anchors.stride
    # 获取锚点的个数
    anchor_num = anchor.shape[0]
    # 将对锚点组进行广播,并设置其坐标。
    anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
    # 加上ori偏移后,xx和yy以图像中心为原点
    ori = - (score_size // 2) * total_stride
    xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
                         [ori + total_stride * dy for dy in range(score_size)])
    xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
             np.tile(yy.flatten(), (anchor_num, 1)).flatten()
    # 获取anchor
    anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
    return anchor
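
# Worked example (not part of the original file): with the config.json above
# (stride 8, five ratios, one scale) there are 5 anchors per position; for a
# 25x25 score map, generate_anchor returns 5 * 25 * 25 = 3125 rows of
# (cx, cy, w, h) whose centers range over [-96, 96] in steps of 8,
# since ori = -(25 // 2) * 8 = -96.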

def siamese_init(im, target_pos, target_sz, model, hp=None, device='cpu'):
    """
    初始化跟踪器,根据目标的信息构建state 字典
    :param im: 当前处理的图像
    :param target_pos: 目标的位置
    :param target_sz: 目标的尺寸
    :param model: 训练好的网络模型
    :param hp: 超参数
    :param device: 硬件信息
    :return: 跟踪器的state字典数据
    """

    # 初始化state字典
    state = dict()
    # 设置图像的宽高
    state['im_h'] = im.shape[0]
    state['im_w'] = im.shape[1]
    # 配置跟踪器的相关参数
    p = TrackerConfig()
    # 对参数进行更新
    p.update(hp, model.anchors)
    # 更新参数
    p.renew()
    # 获取网络模型
    net = model
    # 根据网络参数对跟踪器的参数进行更新,主要是anchors
    p.scales = model.anchors['scales']
    p.ratios = model.anchors['ratios']
    p.anchor_num = model.anchor_num
    # 生成锚点
    p.anchor = generate_anchor(model.anchors, p.score_size)
    # 图像的平均值
    avg_chans = np.mean(im, axis=(0, 1))
    # 根据设置的上下文比例,输入z 的宽高及尺寸
    wc_z = target_sz[0] + p.context_amount * sum(target_sz)
    hc_z = target_sz[1] + p.context_amount * sum(target_sz)
    s_z = round(np.sqrt(wc_z * hc_z))
    # 初始化跟踪目标 initialize the exemplar
    z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans)
    # 将其转换为Variable可在pythorch中进行反向传播
    z = Variable(z_crop.unsqueeze(0))
    # 专门处理模板
    net.template(z.to(device))
    # 设置使用的惩罚窗口
    if p.windowing == 'cosine':
        # 利用hanning窗的外积生成cosine窗口
        window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size))
    elif p.windowing == 'uniform':
        window = np.ones((p.score_size, p.score_size))
    # 每一个anchor都有一个对应的惩罚窗口
    window = np.tile(window.flatten(), p.anchor_num)
    # 将信息更新到state字典中
    state['p'] = p
    state['net'] = net
    state['avg_chans'] = avg_chans
    state['window'] = window
    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    return state
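
# Worked example (not part of the original file): with the default
# context_amount = 0.5 and a 100x50 target, wc_z = 100 + 0.5 * 150 = 175,
# hc_z = 50 + 0.5 * 150 = 125 and s_z = round(sqrt(175 * 125)) = 148, so a
# 148x148 context window is cropped and resized to the 127x127 exemplar.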


def siamese_track(state, im, mask_enable=False, refine_enable=False, device='cpu', debug=False):
    """
    对目标进行跟踪
    :param state:目标状态
    :param im:跟踪的图像帧
    :param mask_enable:是否进行掩膜
    :param refine_enable:是否进行特征的融合
    :param device:硬件信息
    :param debug: 是否进行debug
    :return:跟踪目标的状态 state字典
    """
    # 获取目标状态
    p = state['p']
    net = state['net']
    avg_chans = state['avg_chans']
    window = state['window']
    target_pos = state['target_pos']
    target_sz = state['target_sz']
    # 包含周边信息的跟踪框的宽度,高度,尺寸
    wc_x = target_sz[1] + p.context_amount * sum(target_sz)
    hc_x = target_sz[0] + p.context_amount * sum(target_sz)
    s_x = np.sqrt(wc_x * hc_x)
    # 模板模型输入框尺寸与跟踪框的比例
    scale_x = p.exemplar_size / s_x
    # 使用与模板分支相同的比例得到检测区域
    d_search = (p.instance_size - p.exemplar_size) / 2
    pad = d_search / scale_x
    s_x = s_x + 2 * pad
    # 对检测框进行扩展,包含周边信息
    crop_box = [target_pos[0] - round(s_x) / 2, target_pos[1] - round(s_x) / 2, round(s_x), round(s_x)]
    # 若进行debug
    if debug:
        # 复制图片
        im_debug = im.copy()
        # 产生crop_box
        crop_box_int = np.int0(crop_box)
        # 将其绘制在图片上
        cv2.rectangle(im_debug, (crop_box_int[0], crop_box_int[1]),
                      (crop_box_int[0] + crop_box_int[2], crop_box_int[1] + crop_box_int[3]), (255, 0, 0), 2)
        # 图片展示
        cv2.imshow('search area', im_debug)
        cv2.waitKey(0)

    # extract scaled crops for search region x at previous target position
    # 将目标位置按比例转换为要跟踪的目标
    x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0))
    # 调用网络进行目标跟踪
    if mask_enable:
        # 进行目标分割
        score, delta, mask = net.track_mask(x_crop.to(device))
    else:
        # 只进行目标追踪,不进行分割
        score, delta = net.track(x_crop.to(device))
    # 目标框回归结果(将其转成4*...的样式)
    delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy()
    # 目标分类结果(将其转成2*...的样式)
    score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1).permute(1, 0), dim=1).data[:,
            1].cpu().numpy()
    # 计算目标框的中心点坐标,delta[0],delta[1],以及宽delta[2]和高delta[3],这里变量不是很明确。
    delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0]
    delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1]
    delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2]
    delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3]

    def change(r):
        """
        Element-wise maximum of r and 1/r
        :param r: ratio
        :return: max(r, 1/r)
        """
        return np.maximum(r, 1. / r)

    def sz(w, h):
        """
        Equivalent side length
        :param w: width
        :param h: height
        :return: equivalent side length
        """
        pad = (w + h) * 0.5
        sz2 = (w + pad) * (h + pad)
        return np.sqrt(sz2)

    def sz_wh(wh):
        """
        Equivalent side length
        :param wh: (width, height) array
        :return: equivalent side length
        """
        pad = (wh[0] + wh[1]) * 0.5
        sz2 = (wh[0] + pad) * (wh[1] + pad)
        return np.sqrt(sz2)

    # size penalty
    target_sz_in_crop = target_sz*scale_x
    s_c = change(sz(delta[2, :], delta[3, :]) / (sz_wh(target_sz_in_crop)))  # scale penalty
    r_c = change((target_sz_in_crop[0] / target_sz_in_crop[1]) / (delta[2, :] / delta[3, :]))  # ratio penalty
    # penalty_k is a hyper-parameter (0.04 in config.json)
    penalty = np.exp(-(r_c * s_c - 1) * p.penalty_k)
    # penalize the classification scores
    pscore = penalty * score
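
    # Worked example (not part of the original file): s_c and r_c are both >= 1,
    # so a proposal whose scale and aspect ratio match the previous frame has
    # r_c * s_c = 1 and penalty = exp(0) = 1; with penalty_k = 0.04 a proposal
    # with r_c * s_c = 2 keeps only exp(-0.04) of its score, about 0.96.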

    # cos window (motion model)
    # window penalty: blend in the cosine window with weight window_influence
    pscore = pscore * (1 - p.window_influence) + window * p.window_influence
    # index of the best penalized score
    best_pscore_id = np.argmax(pscore)
    # map the best prediction back to the original image scale
    pred_in_crop = delta[:, best_pscore_id] / scale_x
    # learning rate for the smoothed size update
    lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr  # lr for OTB
    # apply the predicted offsets to get the new position and size
    res_x = pred_in_crop[0] + target_pos[0]
    res_y = pred_in_crop[1] + target_pos[1]

    res_w = target_sz[0] * (1 - lr) + pred_in_crop[2] * lr
    res_h = target_sz[1] * (1 - lr) + pred_in_crop[3] * lr
    # updated target position and size
    target_pos = np.array([res_x, res_y])
    target_sz = np.array([res_w, res_h])

    # for Mask Branch
    if mask_enable:
        # locate the best score on the (anchor, y, x) grid; np.unravel_index
        # converts a flat index into a tuple of grid coordinates
        best_pscore_id_mask = np.unravel_index(best_pscore_id, (5, p.score_size, p.score_size))
        delta_x, delta_y = best_pscore_id_mask[2], best_pscore_id_mask[1]
        # whether to use the refine (feature-fusion) module
        if refine_enable:
            # run the Refine module: fuse the 1x1x256 feature vector from the correlation
            # map with the pre-downsampling feature maps to produce the target mask
            mask = net.track_refine((delta_y, delta_x)).to(device).sigmoid().squeeze().view(
                p.out_size, p.out_size).cpu().data.numpy()
        else:
            # without refinement, take the mask directly from the mask-branch output
            mask = mask[0, :, delta_y, delta_x].sigmoid(). \
                squeeze().view(p.out_size, p.out_size).cpu().data.numpy()

        def crop_back(image, bbox, out_sz, padding=-1):
            """
            Warp the mask back into image coordinates with an affine transform
            :param image: the image (mask) to warp
            :param bbox: bounding box
            :param out_sz: output size
            :param padding: border fill value
            :return: the warped result
            """
            # build the affine matrix
            # scale factors
            a = (out_sz[0] - 1) / bbox[2]
            b = (out_sz[1] - 1) / bbox[3]
            # translation
            c = -a * bbox[0]
            d = -b * bbox[1]
            mapping = np.array([[a, 0, c],
                                [0, b, d]]).astype(float)  # np.float is deprecated; plain float is equivalent
            # apply the affine transform
            crop = cv2.warpAffine(image, mapping, (out_sz[0], out_sz[1]),
                                  flags=cv2.INTER_LINEAR,
                                  borderMode=cv2.BORDER_CONSTANT,
                                  borderValue=padding)
            return crop
        # ratio of the search-region size to the model input size: scaling factor
        s = crop_box[2] / p.instance_size
        # box of the exemplar-sized region the mask was predicted on
        sub_box = [crop_box[0] + (delta_x - p.base_size / 2) * p.total_stride * s,
                   crop_box[1] + (delta_y - p.base_size / 2) * p.total_stride * s,
                   s * p.exemplar_size, s * p.exemplar_size]
        # scaling factor from the mask output to that region
        s = p.out_size / sub_box[2]
        # box covering the whole frame in mask-output coordinates
        back_box = [-sub_box[0] * s, -sub_box[1] * s, state['im_w'] * s, state['im_h'] * s]
        # warp the mask back onto the frame
        mask_in_img = crop_back(mask, back_box, (state['im_w'], state['im_h']))
        # binarize the mask
        target_mask = (mask_in_img > p.seg_thr).astype(np.uint8)
        # findContours returns two values in OpenCV 4 and three in OpenCV 3
        # (this check assumes a version string of the form 'x.y.z')
        if cv2.__version__[-5] == '4':
            contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        else:
            _, contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # contour areas
        cnt_area = [cv2.contourArea(cnt) for cnt in contours]
        if len(contours) != 0 and np.max(cnt_area) > 100:
            # take the contour with the largest area
            contour = contours[np.argmax(cnt_area)]  # use max area polygon
            # reshape to N x 2
            polygon = contour.reshape(-1, 2)
            # pbox = cv2.boundingRect(polygon)  # Min Max Rectangle
            # four corner points of the minimum-area (rotated) bounding rectangle
            prbox = cv2.boxPoints(cv2.minAreaRect(polygon))  # Rotated Rectangle

            # box_in_img = pbox
            # the tracking box
            rbox_in_img = prbox
        else:  # empty mask
            # fall back to the box predicted by the regression branch
            location = cxy_wh_2_rect(target_pos, target_sz)
            # four corner points of that box
            rbox_in_img = np.array([[location[0], location[1]],
                                    [location[0] + location[2], location[1]],
                                    [location[0] + location[2], location[1] + location[3]],
                                    [location[0], location[1] + location[3]]])
    # clamp the target position and size to the image
    target_pos[0] = max(0, min(state['im_w'], target_pos[0]))
    target_pos[1] = max(0, min(state['im_h'], target_pos[1]))
    target_sz[0] = max(10, min(state['im_w'], target_sz[0]))
    target_sz[1] = max(10, min(state['im_h'], target_sz[1]))
    # update the state dict
    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    state['score'] = score[best_pscore_id]
    state['mask'] = mask_in_img if mask_enable else []
    state['ploygon'] = rbox_in_img if mask_enable else []  # 'ploygon' is the key's spelling in the original repo
    return state


def track_vot(model, video, hp=None, mask_enable=False, refine_enable=False, device='cpu'):
    """
    对目标进行追踪
    :param model: 训练好的模型
    :param video: 视频数据
    :param hp: 超参数
    :param mask_enable: 是否生成掩膜,默认为False
    :param refine_enable: 是否使用融合后的模型
    :param device:硬件信息
    :return:目标跟丢次数,fps
    """
    # 记录目标框及其状态
    regions = []  # result and states[1 init / 2 lost / 0 skip]
    # 获取要处理的图像,和真实值groundtruth
    image_files, gt = video['image_files'], video['gt']
    # 设置相关参数:初始帧,终止帧,目标丢失次数,toc
    start_frame, end_frame, lost_times, toc = 0, len(image_files), 0, 0
    # iterate over the frames
    for f, image_file in enumerate(image_files):
        # read the frame
        im = cv2.imread(image_file)
        tic = cv2.getTickCount()
        # on the initial frame
        if f == start_frame:  # init
            # target region as center x, center y, width, height
            cx, cy, w, h = get_axis_aligned_bbox(gt[f])
            # target position
            target_pos = np.array([cx, cy])
            # target size
            target_sz = np.array([w, h])
            # initialize the tracker
            state = siamese_init(im, target_pos, target_sz, model, hp, device)  # init tracker
            # convert the box to top-left x, top-left y, width, height
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
            # for VOT datasets append 1, otherwise append gt[f], the target's true position in the first frame
            regions.append(1 if 'VOT' in args.dataset else gt[f])
        # subsequent frames
        elif f > start_frame:  # tracking
            # track the target
            state = siamese_track(state, im, mask_enable, refine_enable, device, args.debug)  # track
            # with mask prediction
            if mask_enable:
                # flatten the rotated-box corners
                location = state['ploygon'].flatten()
                # the predicted mask
                mask = state['mask']
            # without mask prediction
            else:
                # convert the box to top-left x, top-left y, width, height
                location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
                # empty mask
                mask = []
            # for VOT data compute the overlap with the groundtruth; otherwise default to 1
            if 'VOT' in args.dataset:
                # true target polygon
                gt_polygon = ((gt[f][0], gt[f][1]), (gt[f][2], gt[f][3]),
                              (gt[f][4], gt[f][5]), (gt[f][6], gt[f][7]))
                # predicted polygon with masks enabled
                if mask_enable:
                    pred_polygon = ((location[0], location[1]), (location[2], location[3]),
                                    (location[4], location[5]), (location[6], location[7]))
                # predicted polygon without masks
                else:
                    pred_polygon = ((location[0], location[1]),
                                    (location[0] + location[2], location[1]),
                                    (location[0] + location[2], location[1] + location[3]),
                                    (location[0], location[1] + location[3]))
                # overlap between the two polygons (vot_overlap comes from the
                # pyvotkit import commented out at the top of this file)
                b_overlap = vot_overlap(gt_polygon, pred_polygon, (im.shape[1], im.shape[0]))
            else:
                b_overlap = 1
            # if the tracking box overlaps the groundtruth, record the result
            if b_overlap:
                regions.append(location)
            # otherwise the target is lost: count it and re-initialize five frames later
            else:  # lost
                regions.append(2)
                lost_times += 1
                start_frame = f + 5  # skip 5 frames
        # skip all other frames (e.g. those before the start frame)
        else:  # skip
            regions.append(0)
        # accumulate the tracking time
        toc += cv2.getTickCount() - tic
        # visualization, skipping lost frames
        if args.visualization and f >= start_frame:  # visualization (skip lost frame)
            # work on a copy of the frame
            im_show = im.copy()
            # destroy any stale windows on the first frame
            if f == 0: cv2.destroyAllWindows()
            # when the annotation contains frame f:
            if gt.shape[0] > f:
                # draw the groundtruth on the frame
                if len(gt[f]) == 8:
                    cv2.polylines(im_show, [np.array(gt[f], int).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
                else:
                    cv2.rectangle(im_show, (gt[f, 0], gt[f, 1]), (gt[f, 0] + gt[f, 2], gt[f, 1] + gt[f, 3]), (0, 255, 0), 3)
            # draw the tracking result on the frame
            if len(location) == 8:
                # with masks enabled, overlay the mask as well
                if mask_enable:
                    mask = mask > state['p'].seg_thr
                    im_show[:, :, 2] = mask * 255 + (1 - mask) * im_show[:, :, 2]
                location_int = np.int0(location)
                cv2.polylines(im_show, [location_int.reshape((-1, 1, 2))], True, (0, 255, 255), 3)
            else:
                location = [int(l) for l in location]
                cv2.rectangle(im_show, (location[0], location[1]),
                              (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
            cv2.putText(im_show, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            cv2.putText(im_show, str(lost_times), (40, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(im_show, str(state['score']) if 'score' in state else '', (40, 120), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            cv2.imshow(video['name'], im_show)
            cv2.waitKey(1)
    toc /= cv2.getTickFrequency()

    # save results to a text file
    # the directory name encodes the architecture, mask, refine and resume information
    name = args.arch.split('.')[0] + '_' + ('mask_' if mask_enable else '') + ('refine_' if refine_enable else '') +\
           args.resume.split('/')[-1].split('.')[0]
    # for VOT datasets
    if 'VOT' in args.dataset:
        # build the result path
        video_path = join('test', args.dataset, name,
                          'baseline', video['name'])
        # create it if it does not exist
        if not isdir(video_path): makedirs(video_path)
        # path of the result text file
        result_path = join(video_path, '{:s}_001.txt'.format(video['name']))
        # write the tracking results to the text file
        # (left commented out here; vot_float2str comes from the pyvotkit import above)
        # with open(result_path, "w") as fin:
        #     for x in regions:
        #         fin.write("{:d}\n".format(x)) if isinstance(x, int) else \
        #                 fin.write(','.join([vot_float2str("%.4f", i) for i in x]) + '\n')
    # for OTB data
    else:  # OTB
        # build the result path
        video_path = join('test', args.dataset, name)
        # create it if it does not exist
        if not isdir(video_path): makedirs(video_path)
        # path of the result text file
        result_path = join(video_path, '{:s}.txt'.format(video['name']))
        # write the tracking results to the text file
        with open(result_path, "w") as fin:
            for x in regions:
                fin.write(','.join([str(i) for i in x])+'\n')
    # write a summary line to the log
    logger.info('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps Lost: {:d}'.format(
        v_id, video['name'], toc, f / toc, lost_times))
    # return the results
    return lost_times, f / toc


def MultiBatchIouMeter(thrs, outputs, targets, start=None, end=None):
    """
    批量计算某个目标在视频(多帧图像)中的IOU
    :param thrs:阈值
    :param outputs:追踪的目标结果
    :param targets:真实的目标结果
    :param start:起止帧
    :param end:终止帧
    :return:某个目标的区域相似度
    """
    # 将追踪结果与真实结果转换为ndarray的形式
    targets = np.array(targets)
    outputs = np.array(outputs)
    # 利用标注信息获取视频的帧数
    num_frame = targets.shape[0]
    # 若未指定初始帧
    if start is None:
        # 根据目标跟踪结果确定目标ids
        object_ids = np.array(list(range(outputs.shape[0]))) + 1
    else:
        # 根据指定初始帧确定目标的ids
        object_ids = [int(id) for id in start]
    # 确定目标个数
    num_object = len(object_ids)
    # 用来存储某一目标的交并比
    res = np.zeros((num_object, len(thrs)), dtype=np.float32)
    # 计算掩膜中的最大值及其所在id(该位置认为是目标的位置)
    output_max_id = np.argmax(outputs, axis=0).astype('uint8')+1
    outputs_max = np.max(outputs, axis=0)
    # 遍历阈值
    for k, thr in enumerate(thrs):
        # 若追踪的max大于阈值, output_thr设为1,否则设为0
        output_thr = outputs_max > thr
        # 遍历追踪的目标
        for j in range(num_object):
            # 得到指定的目标
            target_j = targets == object_ids[j]
            # 确定目标所在的视频帧数
            if start is None:
                start_frame, end_frame = 1, num_frame - 1
            else:
                start_frame, end_frame = start[str(object_ids[j])] + 1, end[str(object_ids[j])] - 1
            # 交并比
            iou = []
            # 遍历帧
            for i in range(start_frame, end_frame):
                # 找到追踪结果为j的位置置为1
                pred = (output_thr[i] * output_max_id[i]) == (j+1)
                # 计算真值和追踪结果的和
                mask_sum = (pred == 1).astype(np.uint8) + (target_j[i] > 0).astype(np.uint8)
                # 计算交
                intxn = np.sum(mask_sum == 2)
                # 计算并
                union = np.sum(mask_sum > 0)
                # 计算交并比
                if union > 0:
                    iou.append(intxn / union)
                elif union == 0 and intxn == 0:
                    iou.append(1)
            # 计算目标j,阈值为k时的平均交并比
            res[j, k] = np.mean(iou)
    return res
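
# Worked example (not part of the original file): for a 2x2 frame with
# prediction [[1, 1], [0, 0]] and groundtruth [[1, 0], [0, 0]], mask_sum is
# [[2, 1], [0, 0]]; the intersection (mask_sum == 2) has 1 pixel, the union
# (mask_sum > 0) has 2 pixels, and the IoU is 0.5.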


def track_vos(model, video, hp=None, mask_enable=False, refine_enable=False, mot_enable=False, device='cpu'):
    """
    对数据进行分割并追踪
    :param model: 训练好的模型
    :param video: 视频数据
    :param hp: 超参数
    :param mask_enable: 是否生成掩膜,默认为False
    :param refine_enable: 是否使用融合后的模型
    :param mot_enable:是否进行多目标追踪
    :param device:硬件信息
    :return:区域相似度(掩膜与真值之间的IOU),fps
    """
    # 要处理的图像序列
    image_files = video['image_files']
    # 标注信息:分割中标注的内容也是图像
    annos = [np.array(Image.open(x)) for x in video['anno_files']]
    # 获取初始帧的标注信息
    if 'anno_init_files' in video:
        annos_init = [np.array(Image.open(x)) for x in video['anno_init_files']]
    else:
        annos_init = [annos[0]]
    # 如不进行多目标跟踪,则把多个实例合并为一个示例后进行跟踪
    if not mot_enable:
        # 将标注信息中大于0的置为1,存为掩膜的形式
        annos = [(anno > 0).astype(np.uint8) for anno in annos]
        annos_init = [(anno_init > 0).astype(np.uint8) for anno_init in annos_init]
    # 统计起始帧图像中的目标id
    if 'start_frame' in video:
        object_ids = [int(id) for id in video['start_frame']]
    else:
        # 若起始帧不存在,则根据初始的标注信息确定目标id
        object_ids = [o_id for o_id in np.unique(annos[0]) if o_id != 0]
        # 若目标idgeshu小于初始帧的标注个数时,说明不进行多目标追踪,合并后的标注信息作为每个目标的标准信息
        if len(object_ids) != len(annos_init):
            annos_init = annos_init*len(object_ids)
    # 统计跟踪目标个数
    object_num = len(object_ids)

    toc = 0
    # 用来存放每一帧图像的掩模信息
    pred_masks = np.zeros((object_num, len(image_files), annos[0].shape[0], annos[0].shape[1]))-1
    # 遍历每一个目标,在起止帧之间进行目标跟踪
    for obj_id, o_id in enumerate(object_ids):
        # 确定起止帧的id
        if 'start_frame' in video:
            start_frame = video['start_frame'][str(o_id)]
            end_frame = video['end_frame'][str(o_id)]
        else:
            start_frame, end_frame = 0, len(image_files)
        # 遍历每一帧图像
        for f, image_file in enumerate(image_files):
            im = cv2.imread(image_file)
            tic = cv2.getTickCount()
            # 若是起始帧图像,进行初始化
            if f == start_frame:  # init
                # 确定目标o_id的掩模
                mask = annos_init[obj_id] == o_id
                # 计算mask垂直边界的最小矩形(矩形与图像的上下边界平行)
                x, y, w, h = cv2.boundingRect((mask).astype(np.uint8))
                # 计算边界矩形的中心坐标
                cx, cy = x + w/2, y + h/2
                # 目标位置:矩形中心
                target_pos = np.array([cx, cy])
                # 目标尺寸:矩形尺寸
                target_sz = np.array([w, h])
                # 初始化跟踪器
                state = siamese_init(im, target_pos, target_sz, model, hp, device=device)  # init tracker
            # 若非起始帧图像,则执行跟踪操作
            elif end_frame >= f > start_frame:  # tracking
                # 目标跟踪
                state = siamese_track(state, im, mask_enable, refine_enable, device=device)  # track
                # 某一帧图像掩膜信息
                mask = state['mask']
            toc += cv2.getTickCount() - tic
            # 所有帧图像,更新掩膜信息
            if end_frame >= f >= start_frame:
                # 更新所有帧图像的某一目标的掩膜
                pred_masks[obj_id, f, :, :] = mask
    toc /= cv2.getTickFrequency()
    # when the annotations cover every test image, compute the region similarity
    if len(annos) == len(image_files):
        # batch IoU computation
        multi_mean_iou = MultiBatchIouMeter(thrs, pred_masks, annos,
                                            start=video['start_frame'] if 'start_frame' in video else None,
                                            end=video['end_frame'] if 'end_frame' in video else None)
        # log the IoU of every object
        for i in range(object_num):
            for j, thr in enumerate(thrs):
                logger.info('Fusion Multi Object{:20s} IOU at {:.2f}: {:.4f}'.format(video['name'] + '_' + str(i + 1), thr,
                                                                           multi_mean_iou[i, j]))
    else:
        multi_mean_iou = []
    # save the masks
    if args.save_mask:
        video_path = join('test', args.dataset, 'SiamMask', video['name'])
        if not isdir(video_path): makedirs(video_path)
        pred_mask_final = np.array(pred_masks)
        pred_mask_final = (np.argmax(pred_mask_final, axis=0).astype('uint8') + 1) * (
                np.max(pred_mask_final, axis=0) > state['p'].seg_thr).astype('uint8')
        for i in range(pred_mask_final.shape[0]):
            cv2.imwrite(join(video_path, image_files[i].split('/')[-1].split('.')[0] + '.png'), pred_mask_final[i].astype(np.uint8))
    # visualization; it only runs after all frames are processed, so playback may stutter
    if args.visualization:
        pred_mask_final = np.array(pred_masks)
        pred_mask_final = (np.argmax(pred_mask_final, axis=0).astype('uint8') + 1) * (
                np.max(pred_mask_final, axis=0) > state['p'].seg_thr).astype('uint8')
        COLORS = np.random.randint(128, 255, size=(object_num, 3), dtype="uint8")
        COLORS = np.vstack([[0, 0, 0], COLORS]).astype("uint8")
        mask = COLORS[pred_mask_final]
        for f, image_file in enumerate(image_files):
            output = ((0.4 * cv2.imread(image_file)) + (0.6 * mask[f,:,:,:])).astype("uint8")
            cv2.imshow("mask", output)
            cv2.waitKey(1)

    logger.info('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps'.format(
        v_id, video['name'], toc, f*len(object_ids) / toc))

    return multi_mean_iou, f*len(object_ids) / toc


def main():
    # parse the command-line arguments
    global args, logger, v_id
    args = parser.parse_args()
    # load the configuration file: network structure, hyper-parameters, etc.
    cfg = load_config(args)
    # initialize logging and also write the log to a file on disk
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # record the configuration in the log
    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    # load the network architecture
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))
    # load the network weights
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    # switch to evaluation mode, which disables dropout and freezes batch-norm statistics
    model.eval()
    # hardware information
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # only these three datasets support mask (VOS-style) evaluation
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []
    # process the videos
    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue
        # VOS-style: call track_vos
        if vos_enable:
            # on DAVIS2017 and ytb_vos, multi-object tracking is enabled
            iou_list, speed = track_vos(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                                 args.mask, args.refine, args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        # VOT-style: call track_vot
        else:
            lost, speed = track_vot(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                             args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))


if __name__ == '__main__':
    main()
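
With the argument definitions above, a typical invocation looks like this (--config and --resume are required; --mask and --refine enable the mask branch and the refine module). Note that the VOT overlap computation additionally needs the pyvotkit import that is commented out at the top of the file:

python tools/test.py --config config.json --resume SiamMask.pth \
    --mask --refine --dataset VOT2018 --visualization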

train_siammask.py

# --------------------------------------------------------
# Training of the base network
# --------------------------------------------------------
import argparse
import logging
import os
import cv2
import shutil
import time
import json
import math
import torch
from torch.utils.data import DataLoader

from utils.log_helper import init_log, print_speed, add_file_handler, Dummy
from utils.load_helper import load_pretrain, restore_from
from utils.average_meter_helper import AverageMeter

from datasets.siam_mask_dataset import DataSets

from utils.lr_helper import build_lr_scheduler
from tensorboardX import SummaryWriter

from utils.config_helper import load_config
from torch.utils.collect_env import get_pretty_env_info

torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch Tracking SiamMask Training')

parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                    help='number of data loading workers (default: 0)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch', default=64, type=int,
                    metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--clip', default=10.0, type=float,
                    help='gradient clip value')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', default='',
                    help='use pre-trained model')
parser.add_argument('--config', dest='config', required=True,
                    help='hyperparameter of SiamMask in json format')
parser.add_argument('--arch', dest='arch', default='', choices=['Custom',],
                    help='architecture of pretrained model')
parser.add_argument('-l', '--log', default="log.txt", type=str,
                    help='log file')
parser.add_argument('-s', '--save_dir', default='snapshot', type=str,
                    help='save dir')
parser.add_argument('--log-dir', default='board', help='TensorBoard log dir')


best_acc = 0.


def collect_env_info():
    """
    环境信息
    :return:
    """
    env_str = get_pretty_env_info()
    env_str += "\n        OpenCV ({})".format(cv2.__version__)
    return env_str


def build_data_loader(cfg):
    """
    获取数据集
    :param cfg:
    :return:
    """
    logger = logging.getLogger('global')

    logger.info("build train dataset")  # train_dataset
    # 获取训练集数据,包含数据增强的内容
    train_set = DataSets(cfg['train_datasets'], cfg['anchors'], args.epochs)
    # 对数据进行打乱处理
    train_set.shuffle()

    # 获取验证集数据,若为配置验证集数据则使用训练集数据替代
    logger.info("build val dataset")  # val_dataset
    if not 'val_datasets' in cfg.keys():
        cfg['val_datasets'] = cfg['train_datasets']
    val_set = DataSets(cfg['val_datasets'], cfg['anchors'])
    val_set.shuffle()
    # DataLoader是Torch内置的方法,它允许使用多线程加速数据的读取
    train_loader = DataLoader(train_set, batch_size=args.batch, num_workers=args.workers,
                              pin_memory=True, sampler=None)
    val_loader = DataLoader(val_set, batch_size=args.batch, num_workers=args.workers,
                            pin_memory=True, sampler=None)

    logger.info('build dataset done')
    return train_loader, val_loader


def build_opt_lr(model, cfg, args, epoch):
    """
    获取优化方法和学习率
    :param model:
    :param cfg:
    :param args:
    :param epoch:
    :return:
    """
    # 获取要训练的网络
    backbone_feature = model.features.param_groups(cfg['lr']['start_lr'], cfg['lr']['feature_lr_mult'])
    if len(backbone_feature) == 0:
        # 获取要训练的rpn网络的参数
        trainable_params = model.rpn_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['rpn_lr_mult'], 'mask')
    else:
        # 获取基础网络,rpn和mask网络的训练参数
        trainable_params = backbone_feature + \
                           model.rpn_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['rpn_lr_mult']) + \
                           model.mask_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['mask_lr_mult'])
    # 随机梯度下降算法优化
    optimizer = torch.optim.SGD(trainable_params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # 获取学习率
    lr_scheduler = build_lr_scheduler(optimizer, cfg['lr'], epochs=args.epochs)
    # 更新学习率
    lr_scheduler.step(epoch)
    # 返回优化器和学习率
    return optimizer, lr_scheduler


def main():
    """
    基础网络的训练
    :return:
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    # 初始化日志信息
    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # 获取log信息
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)
    # 获取配置信息
    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build the datasets
    train_loader, val_loader = build_data_loader(cfg)
    # build the network to train
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)
    # load the pretrained weights
    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    # GPU version
    # model = model.cuda()
    # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    # wrap the model
    dist_model = torch.nn.DataParallel(model)
    # when resuming, unfreeze the fraction of the backbone matching the start epoch
    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)
    # build the optimizer and the learning-rate schedule
    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        # GPU
        # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    # train the model
    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)


def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    """
    模型训练
    :param train_loader:训练数据
    :param model:
    :param optimizer:
    :param lr_scheduler:
    :param epoch:
    :param cfg:
    :return:
    """

    global tb_index, best_acc, cur_lr, logger
    # 获取当前的学习率
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    #
    avg = AverageMeter()
    model.train()
    # GPU
    #  model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        return not(math.isnan(x) or math.isinf(x) or x > 1e4)

    num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch
    print("num_per_epoch",num_per_epoch)
    start_epoch = epoch
    epoch = epoch
    # iterate over the training batches
    for iter, input in enumerate(train_loader):
        if epoch != iter // num_per_epoch + start_epoch:  # next epoch
            epoch = iter // num_per_epoch + start_epoch
            # create the save directory if needed
            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)
            # save a checkpoint for the finished epoch
            save_checkpoint({
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            if epoch == args.epochs:
                return
            # unfreeze more of the backbone and, if it changed, rebuild the optimizer and scheduler
            if model.module.features.unfix(epoch/args.epochs):
                logger.info('unfix part model.')
                optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args, epoch)
            # update the current learning rate
            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()

            logger.info('epoch:{}'.format(epoch))
        # update the TensorBoard step index and log the per-group learning rates
        tb_index = iter
        if iter % num_per_epoch == 0 and iter != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx+1), pg['lr'], tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        # assemble the network input
        x = {
            # GPU
            # 'cfg': cfg,
            # 'template': torch.autograd.Variable(input[0]).cuda(),
            # 'search': torch.autograd.Variable(input[1]).cuda(),
            # 'label_cls': torch.autograd.Variable(input[2]).cuda(),
            # 'label_loc': torch.autograd.Variable(input[3]).cuda(),
            # 'label_loc_weight': torch.autograd.Variable(input[4]).cuda(),
            # 'label_mask': torch.autograd.Variable(input[6]).cuda(),
            # 'label_mask_weight': torch.autograd.Variable(input[7]).cuda(),
            'cfg': cfg,
            'template': torch.autograd.Variable(input[0]),
            'search': torch.autograd.Variable(input[1]),
            'label_cls': torch.autograd.Variable(input[2]),
            'label_loc': torch.autograd.Variable(input[3]),
            'label_loc_weight': torch.autograd.Variable(input[4]),
            'label_mask': torch.autograd.Variable(input[6]),
            'label_mask_weight': torch.autograd.Variable(input[7]),
        }
        # forward pass
        outputs = model(x)

        # per-branch losses
        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = torch.mean(outputs['losses'][0]), torch.mean(outputs['losses'][1]), torch.mean(outputs['losses'][2])
        # mask accuracy metrics
        mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(outputs['accuracy'][0]), torch.mean(outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2])
        # weights of the classification, regression and mask losses
        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']
        # weighted total loss
        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight
        # zero the gradients
        optimizer.zero_grad()
        # backpropagate
        loss.backward()

        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(), cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(), cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(model.module.mask_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # gradient clip

        if is_valid_number(loss.item()):
            optimizer.step()

        siammask_loss = loss.item()

        batch_time = time.time() - end
        # update the running averages
        avg.update(batch_time=batch_time, rpn_cls_loss=rpn_cls_loss, rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss, siammask_loss=siammask_loss,
                   mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)
        # write the metrics to TensorBoard
        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss, tb_index)
        tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        tb_writer.add_scalar('mask/AP@.5', mask_iou_at_5, tb_index)
        tb_writer.add_scalar('mask/AP@.7', mask_iou_at_7, tb_index)
        end = time.time()
        # periodic log output
        if (iter + 1) % args.print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                        '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                        '\t{mask_iou_mean:s}\t{mask_iou_at_5:s}\t{mask_iou_at_7:s}'.format(
                        epoch+1, (iter + 1) % num_per_epoch, num_per_epoch, lr=cur_lr, batch_time=avg.batch_time,
                        data_time=avg.data_time, rpn_cls_loss=avg.rpn_cls_loss, rpn_loc_loss=avg.rpn_loc_loss,
                        rpn_mask_loss=avg.rpn_mask_loss, siammask_loss=avg.siammask_loss, mask_iou_mean=avg.mask_iou_mean,
                        mask_iou_at_5=avg.mask_iou_at_5,mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(iter + 1, avg.batch_time.avg, args.epochs * num_per_epoch)


def save_checkpoint(state, is_best, filename='checkpoint.pth', best_file='model_best.pth'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_file)


if __name__ == '__main__':
    main()
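
To make the loss combination in train() concrete, here is the weighting worked through with the default weights [1, 1, 36] that config_helper falls back to when the config has no 'weight' entry; the mask loss dominates by design. The individual loss values below are made up for illustration:

import torch

rpn_cls_loss = torch.tensor(0.40)   # example values, not real training output
rpn_loc_loss = torch.tensor(0.25)
rpn_mask_loss = torch.tensor(0.02)
cls_weight, reg_weight, mask_weight = 1, 1, 36
loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight
print(loss.item())  # 0.40 + 0.25 + 36 * 0.02 = 1.37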

train_siammask_refine.py

# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# This file trains the mask-refinement network; training is similar to the base network,
# but the loaded model includes the refine branch
# --------------------------------------------------------
import argparse
import logging
import os
import cv2
import shutil
import time
import json
import math
import torch
from torch.utils.data import DataLoader

from utils.log_helper import init_log, print_speed, add_file_handler, Dummy
from utils.load_helper import load_pretrain, restore_from
from utils.average_meter_helper import AverageMeter

from datasets.siam_mask_dataset import DataSets

from utils.lr_helper import build_lr_scheduler
from tensorboardX import SummaryWriter

from utils.config_helper import load_config
from torch.utils.collect_env import get_pretty_env_info

torch.backends.cudnn.benchmark = True


parser = argparse.ArgumentParser(description='PyTorch Tracking Training')

parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch', default=64, type=int,
                    metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--clip', default=10.0, type=float,
                    help='gradient clip value')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', default='',
                    help='use pre-trained model')
parser.add_argument('--config', dest='config', required=True,
                    help='hyperparameter of SiamMask in json format')
parser.add_argument('--arch', dest='arch', default='', choices=['Custom',''],
                    help='architecture of pretrained model')
parser.add_argument('-l', '--log', default="log.txt", type=str,
                    help='log file')
parser.add_argument('-s', '--save_dir', default='snapshot', type=str,
                    help='save dir')
parser.add_argument('--log-dir', default='board', help='TensorBoard log dir')


best_acc = 0.


def collect_env_info():
    env_str = get_pretty_env_info()
    env_str += "\n        OpenCV ({})".format(cv2.__version__)
    return env_str


def build_data_loader(cfg):
    logger = logging.getLogger('global')

    logger.info("build train dataset")  # train_dataset
    train_set = DataSets(cfg['train_datasets'], cfg['anchors'], args.epochs)
    train_set.shuffle()

    logger.info("build val dataset")  # val_dataset
    if not 'val_datasets' in cfg.keys():
        cfg['val_datasets'] = cfg['train_datasets']
    val_set = DataSets(cfg['val_datasets'], cfg['anchors'])
    val_set.shuffle()

    train_loader = DataLoader(train_set, batch_size=args.batch, num_workers=args.workers,
                              pin_memory=True, sampler=None)
    val_loader = DataLoader(val_set, batch_size=args.batch, num_workers=args.workers,
                            pin_memory=True, sampler=None)

    logger.info('build dataset done')
    return train_loader, val_loader


def build_opt_lr(model, cfg, args, epoch):
    '''
    Build the optimizer and learning-rate scheduler for the refine stage
    :param model:
    :param cfg:
    :param args:
    :param epoch:
    :return:
    '''
    # only the mask and refine branches are trained in this stage
    trainable_params = model.mask_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['mask_lr_mult']) + \
                       model.refine_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['mask_lr_mult'])
    # stochastic gradient descent
    optimizer = torch.optim.SGD(trainable_params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # build the learning-rate scheduler
    lr_scheduler = build_lr_scheduler(optimizer, cfg['lr'], epochs=args.epochs)
    # advance the scheduler to the current epoch
    lr_scheduler.step(epoch)

    return optimizer, lr_scheduler


def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        exit()
        # model = models.__dict__[args.arch](anchors=cfg['anchors'])

    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    # GPU
    # model = model.cuda()
    # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    dist_model = torch.nn.DataParallel(model)

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        # GPU
        # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)


def BNtoFixed(m):
    class_name = m.__class__.__name__
    if class_name.find('BatchNorm') != -1:
        m.eval()


def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    model.module.features.eval()
    model.module.rpn_model.eval()
    model.module.features.apply(BNtoFixed)
    model.module.rpn_model.apply(BNtoFixed)

    model.module.mask_model.train()
    model.module.refine_model.train()
    # GPU
    # model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        return not(math.isnan(x) or math.isinf(x) or x > 1e4)

    num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch
    start_epoch = epoch
    for iter, input in enumerate(train_loader):

        if epoch != iter // num_per_epoch + start_epoch:  # next epoch
            epoch = iter // num_per_epoch + start_epoch

            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)

            save_checkpoint({
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            if epoch == args.epochs:
                return

            optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args, epoch)

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()

            logger.info('epoch:{}'.format(epoch))

        tb_index = iter
        if iter % num_per_epoch == 0 and iter != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx+1), pg['lr'], tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        x = {
            # GPU
            # 'cfg': cfg,
            # 'template': torch.autograd.Variable(input[0]).cuda(),
            # 'search': torch.autograd.Variable(input[1]).cuda(),
            # 'label_cls': torch.autograd.Variable(input[2]).cuda(),
            # 'label_loc': torch.autograd.Variable(input[3]).cuda(),
            # 'label_loc_weight': torch.autograd.Variable(input[4]).cuda(),
            # 'label_mask': torch.autograd.Variable(input[6]).cuda(),
            # 'label_mask_weight': torch.autograd.Variable(input[7]).cuda(),

            'cfg': cfg,
            'template': torch.autograd.Variable(input[0]),
            'search': torch.autograd.Variable(input[1]),
            'label_cls': torch.autograd.Variable(input[2]),
            'label_loc': torch.autograd.Variable(input[3]),
            'label_loc_weight': torch.autograd.Variable(input[4]),
            'label_mask': torch.autograd.Variable(input[6]),
            'label_mask_weight': torch.autograd.Variable(input[7]),
        }

        outputs = model(x)

        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = torch.mean(outputs['losses'][0]), torch.mean(outputs['losses'][1]), torch.mean(outputs['losses'][2])
        mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(outputs['accuracy'][0]), torch.mean(outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2])

        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']

        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight

        optimizer.zero_grad()
        loss.backward()

        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(), cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(), cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(model.module.mask_model.parameters(), cfg['clip']['mask'])
            torch.nn.utils.clip_grad_norm_(model.module.refine_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # gradient clip

        if is_valid_number(loss.item()):
            optimizer.step()

        siammask_loss = loss.item()

        batch_time = time.time() - end

        avg.update(batch_time=batch_time, rpn_cls_loss=rpn_cls_loss, rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss, siammask_loss=siammask_loss,
                   mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)

        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss, tb_index)
        tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        tb_writer.add_scalar('mask/AP@.5', mask_iou_at_5, tb_index)
        tb_writer.add_scalar('mask/AP@.7', mask_iou_at_7, tb_index)
        end = time.time()

        if (iter + 1) % args.print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                        '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                        '\t{mask_iou_mean:s}\t{mask_iou_at_5:s}\t{mask_iou_at_7:s}'.format(
                        epoch+1, (iter + 1) % num_per_epoch, num_per_epoch, lr=cur_lr, batch_time=avg.batch_time,
                        data_time=avg.data_time, rpn_cls_loss=avg.rpn_cls_loss, rpn_loc_loss=avg.rpn_loc_loss,
                        rpn_mask_loss=avg.rpn_mask_loss, siammask_loss=avg.siammask_loss, mask_iou_mean=avg.mask_iou_mean,
                        mask_iou_at_5=avg.mask_iou_at_5,mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(iter + 1, avg.batch_time.avg, args.epochs * num_per_epoch)


def save_checkpoint(state, is_best, filename='checkpoint.pth', best_file='model_best.pth'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_file)


if __name__ == '__main__':
    main()
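
The BNtoFixed pattern used in the refine training above deserves a note: Module.apply() walks every submodule, and putting only the BatchNorm layers into eval() freezes their running statistics while the rest of the model keeps training. A self-contained check of that behaviour:

import torch.nn as nn

def BNtoFixed(m):
    # same test as above: match any BatchNorm variant by class name
    if m.__class__.__name__.find('BatchNorm') != -1:
        m.eval()

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.train()
net.apply(BNtoFixed)
print(net[1].training)  # False: the BN layer stays in eval mode inside a training model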

videoDemo.py

from __future__ import division
import flask
import argparse
import numpy as np
import cv2
from os.path import join, isdir, isfile

from utils.load_helper import load_pretrain
import torch
from utils.config_helper import load_config
from tools.test import *


# 1. create the argument parser
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')

# 2. add the arguments
# 2.1 resume: path to the trained weights
parser.add_argument('--resume', default='SiamMask.pth', type=str,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
# 2.2 config: hyper-parameter file
parser.add_argument('--config', dest='config', default='config.json',
                    help='hyper-parameter of SiamMask in json format')
# 2.3 path of the data to process
parser.add_argument('--base_path', default='../../data/car', help='datasets')
# 2.4 hardware flag
parser.add_argument('--cpu', action='store_true', help='cpu mode')
# 3. parse the arguments
args = parser.parse_args()


def process_vedio(vedio_path, initRect):
    """
    视频处理
    :param vedio_path:视频路径
    :param initRect: 跟踪目标的初始位置
    :return:
    """

    # 1. 设置设备信息 Setup device
    # 有GPU时选择GPU,否则使用CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 默认优化运行效率
    torch.backends.cudnn.benchmark = True

    # 2. 模型设置 Setup Model
    # 2.1 将命令行参数解析出来
    cfg = load_config(args)

    # 2.2 custom是构建的网络,否则引用model中的网络结构
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    # 2.3 判断是否存在模型的权重文件
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)
    # 在运行推断前,需要调用 model.eval() 函数,以将 dropout 层 和 batch normalization 层设置为评估模式(非训练模式).
    # to(device)将张量复制到GPU上,之后的计算将在GPU上运行
    siammask.eval().to(device)

    # 首帧跟踪目标的位置
    x, y, w, h = initRect
    print(x)
    VeryBig = 999999999  # 用于将视频框调整到最大
    Cap = cv2.VideoCapture(vedio_path)  # 设置读取摄像头
    ret, frame = Cap.read()  # 读取帧
    ims = [frame]  # 把frame放入列表格式的frame, 因为原文是将每帧图片放入列表

    im = frame
    f = 0
    target_pos = np.array([x + w / 2, y + h / 2])
    target_sz = np.array([w, h])
    state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'])  # init tracker"
    middlepath = "../data/middle.mp4"
    outpath = "../data/output.mp4"
    vediowriter = cv2.VideoWriter(middlepath, cv2.VideoWriter_fourcc('M', 'P', '4', 'V'), 10, (320, 240))
    while True:
        tic = cv2.getTickCount()
        ret, im = Cap.read()  # read the next frame
        if not ret:
            break
        state = siamese_track(state, im, mask_enable=True, refine_enable=True)  # track
        location = state['ploygon'].flatten()
        mask = state['mask'] > state['p'].seg_thr
        # paint the segmented region red in the output frame
        im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
        cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
        vediowriter.write(im)
        cv2.imshow('SiamMask', im)
        key = cv2.waitKey(1)
        if key > 0:
            break

        f = f + 1
    vediowriter.release()

    return


if __name__ == '__main__':
    process_vedio('../data/car.mp4', [162, 121, 28, 25])
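
One caveat about the VideoWriter above: OpenCV silently produces an unplayable file when the written frames do not match the declared size, and the hard-coded (320, 240) only works if the input video already has that resolution. A safer sketch takes the size and frame rate from the capture itself (the paths are the same assumed ones as above):

import cv2

cap = cv2.VideoCapture('../data/car.mp4')
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) or 10  # fall back to 10 fps if the file reports 0
writer = cv2.VideoWriter('../data/middle.mp4', cv2.VideoWriter_fourcc(*'MP4V'), fps, (w, h))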

SiamMask_master\utils

anchors.py

# --------------------------------------------------------
# anchor helper classes
# --------------------------------------------------------
import numpy as np
import math
from utils.bbox_helper import center2corner, corner2center


class Anchors:
    """
    Anchor generator
    """
    def __init__(self, cfg):
        self.stride = 8  # stride of the anchors (spacing between anchor positions, in pixels)
        self.ratios = [0.33, 0.5, 1, 2, 3]  # aspect ratios of the anchors
        self.scales = [8]  # scales of the anchors
        self.round_dight = 0  # rounding precision, for Python 2/3 compatibility
        self.image_center = 0  # the base anchors are centered at the origin
        self.size = 0
        self.anchor_density = 1  # anchor density, i.e. how many anchor centers per stride

        self.__dict__.update(cfg)

        self.anchor_num = len(self.scales) * len(self.ratios) * (self.anchor_density**2)  # number of anchors per position
        self.anchors = None  # anchors at a single position, shape (anchor_num, 4)
        self.all_anchors = None  # anchors at all positions, a pair of arrays of shape (4, anchor_num, h, w): one in [x1, y1, x2, y2] format and one in [cx, cy, w, h] format
        self.generate_anchors()

    def generate_anchors(self):
        """
        Generate the base anchors for a single position
        :return:
        """
        # allocate an all-zero array for the anchors
        self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
        # base anchor area
        size = self.stride * self.stride
        count = 0
        # spacing between anchor centers within one stride
        anchors_offset = self.stride / self.anchor_density
        # offsets of the anchor centers relative to the position, centered around zero
        anchors_offset = np.arange(self.anchor_density)*anchors_offset
        anchors_offset = anchors_offset - np.mean(anchors_offset)
        # use meshgrid to produce the offsets in x and y
        x_offsets, y_offsets = np.meshgrid(anchors_offset, anchors_offset)
        # generate the anchors for every offset
        for x_offset, y_offset in zip(x_offsets.flatten(), y_offsets.flatten()):
            # iterate over the aspect ratios
            for r in self.ratios:
                # compute the anchor width and height
                if self.round_dight > 0:
                    ws = round(math.sqrt(size*1. / r), self.round_dight)
                    hs = round(ws * r, self.round_dight)
                else:
                    ws = int(math.sqrt(size*1. / r))
                    hs = int(ws * r)
                # scale the anchor to each configured scale
                for s in self.scales:
                    w = ws * s
                    h = hs * s
                    self.anchors[count][:] = [-w*0.5+x_offset, -h*0.5+y_offset, w*0.5+x_offset, h*0.5+y_offset][:]
                    count += 1


    def generate_all_anchors(self, im_c, size):
        """
        Generate the anchors for the whole feature map
        :param im_c: image center
        :param size: feature-map size
        :return:
        """
        if self.image_center == im_c and self.size == size:
            return False
        # cache the new configuration
        self.image_center = im_c
        self.size = size
        # x coordinate of anchor 0; x and y are symmetric
        a0x = im_c - size // 2 * self.stride
        # coordinates of anchor 0
        ori = np.array([a0x] * 4, dtype=np.float32)
        # anchors centered on the image center
        zero_anchors = self.anchors + ori
        # corner coordinates of anchor 0
        x1 = zero_anchors[:, 0]
        y1 = zero_anchors[:, 1]
        x2 = zero_anchors[:, 2]
        y2 = zero_anchors[:, 3]

        x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2])
        cx, cy, w, h = corner2center([x1, y1, x2, y2])
        # disp_x has shape [1, 1, size], disp_y has shape [1, size, 1]
        disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
        disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride
        # anchor center coordinates over the whole feature map
        cx = cx + disp_x
        cy = cy + disp_y

        # broadcast to the full anchor grid
        zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
        cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
        x1, y1, x2, y2 = center2corner([cx, cy, w, h])
        # store the anchors in both corner and center formats
        self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h])
        return True


if __name__ == '__main__':
    anchors = Anchors(cfg={'stride':16, 'anchor_density': 2})
    anchors.generate_all_anchors(im_c=255//2, size=(255-127)//16+1+8)
    print(anchors.all_anchors)
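
A quick numeric check of the width/height formula in generate_anchors: with stride = 8 the base area is 64, so for each ratio r the anchor roughly satisfies ws * hs ≈ 64 with hs / ws = r, before the scale factor is applied:

import math

size = 8 * 8  # stride * stride
for r in [0.33, 0.5, 1, 2, 3]:
    ws = int(math.sqrt(size / r))
    hs = int(ws * r)
    print(r, ws * 8, hs * 8)  # anchor width and height after applying scale 8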

average_meter_helper.py

# --------------------------------------------------------
# compute and store metric values
# --------------------------------------------------------
import numpy as np


class Meter(object):
    "A single named metric"
    def __init__(self, name, val, avg):
        # metric name
        self.name = name
        # current value
        self.val = val
        # running average
        self.avg = avg

    def __repr__(self):
        return "{name}: {val:.6f} ({avg:.6f})".format(
            name=self.name, val=self.val, avg=self.avg
        )

    def __format__(self, *tuples, **kwargs):
        return self.__repr__()


class AverageMeter(object):
    """计算平均值和当前值并进行存储"""
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        # clear the stored values
        self.val = {}
        self.sum = {}
        self.count = {}

    def update(self, batch=1, **kwargs):
        # update the stored metrics
        val = {}
        for k in kwargs:
            # per-batch value
            val[k] = kwargs[k] / float(batch)
        self.val.update(val)
        for k in kwargs:
            # accumulate sum and count
            if k not in self.sum:
                self.sum[k] = 0
                self.count[k] = 0
            self.sum[k] += kwargs[k]
            self.count[k] += batch

    def __repr__(self):
        s = ''
        for k in self.sum:
            s += self.format_str(k)
        return s

    def format_str(self, attr):
        # formatted output
        return "{name}: {val:.6f} ({avg:.6f}) ".format(
                    name=attr,
                    val=float(self.val[attr]),
                    avg=float(self.sum[attr]) / self.count[attr])

    def __getattr__(self, attr):
        if attr in self.__dict__:
            return super(AverageMeter, self).__getattr__(attr)
        if attr not in self.sum:
            # logger.warn("invalid key '{}'".format(attr))
            print("invalid key '{}'".format(attr))
            return Meter(attr, 0, 0)
        return Meter(attr, self.val[attr], self.avg(attr))

    def avg(self, attr):
        return float(self.sum[attr]) / self.count[attr]


class IouMeter(object):
    """Computes and stores IoU metrics"""
    def __init__(self, thrs, sz):
        "initialization"
        self.sz = sz
        self.iou = np.zeros((sz, len(thrs)), dtype=np.float32)
        self.thrs = thrs
        self.reset()

    def reset(self):
        # reset the stored IoUs
        self.iou.fill(0.)
        self.n = 0

    def add(self, output, target):
        'add one prediction/target pair'
        if self.n >= len(self.iou):
            return
        target, output = target.squeeze(), output.squeeze()
        # compute the IoU at each threshold
        for i, thr in enumerate(self.thrs):
            pred = output > thr
            mask_sum = (pred == 1).astype(np.uint8) + (target > 0).astype(np.uint8)
            # intersection: pixels set in both masks
            intxn = np.sum(mask_sum == 2)
            # union: pixels set in either mask
            union = np.sum(mask_sum > 0)
            if union > 0:
                # intersection over union
                self.iou[self.n, i] = intxn / union
            elif union == 0 and intxn == 0:
                # both masks empty: define IoU as 1
                self.iou[self.n, i] = 1
        self.n += 1

    def value(self, s):
        nb = max(int(np.sum(self.iou > 0)), 1)
        iou = self.iou[:nb]

        def is_number(s):
            "check whether s parses as a number"
            try:
                float(s)
                return True
            except ValueError:
                return False
        if s == 'mean':
            # mean IoU
            res = np.mean(iou, axis=0)
        elif s == 'median':
            # median IoU
            res = np.median(iou, axis=0)
        elif is_number(s):
            # fraction of samples with IoU above the given threshold
            res = np.sum(iou > float(s), axis=0) / float(nb)
        return res


if __name__ == '__main__':
    avg = AverageMeter()
    avg.update(time=1.1, accuracy=.99)
    avg.update(time=1.0, accuracy=.90)

    print(avg)
    print(avg.sum)
    print(avg.time)
    print(avg.time.avg)
    print(avg.time.val)
    print(avg.SS)
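
For reference, the expected behaviour of this demo (values computed by hand, the exact print formatting may differ): avg.time.avg is (1.1 + 1.0) / 2 = 1.05 and avg.time.val is the last update, 1.0; avg.SS is not a tracked key, so __getattr__ prints "invalid key 'SS'" and returns a zeroed Meter instead of raising.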



bbox_helper.py

# --------------------------------------------------------
# bounding-box helper functions
# --------------------------------------------------------
import numpy as np
from collections import namedtuple
# Corner format: top-left and bottom-right coordinates
Corner = namedtuple('Corner', 'x1 y1 x2 y2')
BBox = Corner
# Center format: center coordinates plus width and height
Center = namedtuple('Center', 'x y w h')


def corner2center(corner):
    """
    Convert top-left/bottom-right coordinates to center coordinates plus width and height
    :param corner: Corner or np.array 4*N
    :return: Center or 4 np.array N
    """
    # Corner input
    if isinstance(corner, Corner):
        x1, y1, x2, y2 = corner
        # center coordinates, width and height
        return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
    else:
        x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
        # center coordinates
        x = (x1 + x2) * 0.5
        y = (y1 + y2) * 0.5
        # width and height
        w = x2 - x1
        h = y2 - y1
        return x, y, w, h


def center2corner(center):
    """
    Convert center coordinates plus width and height to top-left/bottom-right coordinates
    :param center: Center or np.array 4*N
    :return: Corner or np.array 4*N
    """
    # Center input
    if isinstance(center, Center):
        x, y, w, h = center
        return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
    else:
        x, y, w, h = center[0], center[1], center[2], center[3]
        # top-left corner
        x1 = x - w * 0.5
        y1 = y - h * 0.5
        # bottom-right corner
        x2 = x + w * 0.5
        y2 = y + h * 0.5
        return x1, y1, x2, y2
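
A quick round-trip check of the two conversions, using the definitions above:

box = Corner(10, 20, 50, 100)
center = corner2center(box)   # Center(x=30.0, y=60.0, w=40, h=80)
print(center2corner(center))  # back to Corner(x1=10.0, y1=20.0, x2=50.0, y2=100.0)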


def cxy_wh_2_rect(pos, sz):
    """
    Change the box representation
    :param pos: box center coordinates
    :param sz: box size: width and height
    :return: box as top-left corner, width, height
    """
    return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]])  # 0-index


def get_axis_aligned_bbox(region):
    """
    Convert the target region to an axis-aligned box given as center coordinates, width and height
    :param region: 8 polygon coordinates or (x, y, w, h)
    :return: center coordinates, width, height
    """
    nv = region.size
    # region given as four corners, possibly not axis-aligned
    if nv == 8:
        # center coordinates
        cx = np.mean(region[0::2])
        cy = np.mean(region[1::2])
        # top-left and bottom-right corners of the bounding rectangle
        x1 = min(region[0::2])
        x2 = max(region[0::2])
        y1 = min(region[1::2])
        y2 = max(region[1::2])
        # the L2 norms give the side lengths
        # area of the (possibly rotated) quadrilateral
        A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
        # area of the axis-aligned bounding rectangle
        A2 = (x2 - x1) * (y2 - y1)
        s = np.sqrt(A1 / A2)
        # scale width and height so the box keeps the original area
        w = s * (x2 - x1) + 1
        h = s * (y2 - y1) + 1
    # region given as top-left corner plus width and height
    else:
        x = region[0]
        y = region[1]
        w = region[2]
        h = region[3]
        # center coordinates
        cx = x+w/2
        cy = y+h/2

    return cx, cy, w, h
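
Sanity check for the 8-value branch: an axis-aligned square given as four corners comes back essentially unchanged, since A1 = A2 makes the scale s = 1 (so w = h = side + 1):

import numpy as np

region = np.array([0, 0, 10, 0, 10, 10, 0, 10], dtype=np.float64)
print(get_axis_aligned_bbox(region))  # (5.0, 5.0, 11.0, 11.0)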


def aug_apply(bbox, param, shape, inv=False, rd=False):
    """
    Apply augmentation to a box
    :param bbox: original bbox in image
    :param param: augmentation param, shift/scale
    :param shape: image shape, h, w, (c)
    :param inv: inverse
    :param rd: round bbox
    :return: bbox(, param)
        bbox: augmented bbox
        param: real augmentation param
    """
    if not inv:
        # convert to center format
        center = corner2center(bbox)
        original_center = center

        real_param = {}
        # scaling
        if 'scale' in param:
            # requested scale factors
            scale_x, scale_y = param['scale']
            imh, imw = shape[:2]
            # current width and height
            h, w = center.h, center.w
            # clamp the scale so the box stays inside the image
            scale_x = min(scale_x, float(imw) / w)
            scale_y = min(scale_y, float(imh) / h)

            # center.w *= scale_x
            # center.h *= scale_y
            # scaled center box
            center = Center(center.x, center.y, center.w * scale_x, center.h * scale_y)
        # back to corner format (x1, y1, x2, y2)
        bbox = center2corner(center)
        # shifting
        if 'shift' in param:
            tx, ty = param['shift']
            x1, y1, x2, y2 = bbox
            imh, imw = shape[:2]
            # clamp the shift so the box stays inside the image
            tx = max(-x1, min(imw - 1 - x2, tx))
            ty = max(-y1, min(imh - 1 - y2, ty))

            bbox = Corner(x1 + tx, y1 + ty, x2 + tx, y2 + ty)

        if rd:
            bbox = Corner(*map(round, bbox))

        current_center = corner2center(bbox)
        # the scale and shift that were actually applied
        real_param['scale'] = current_center.w / original_center.w, current_center.h / original_center.h
        real_param['shift'] = current_center.x - original_center.x, current_center.y - original_center.y

        return bbox, real_param
    else:
        # inverse scaling
        if 'scale' in param:
            scale_x, scale_y = param['scale']
        else:
            scale_x, scale_y = 1., 1.
        # inverse shift
        if 'shift' in param:
            tx, ty = param['shift']
        else:
            tx, ty = 0, 0
        # undo shift and scale in center format
        center = corner2center(bbox)
        center = Center(center.x - tx, center.y - ty, center.w / scale_x, center.h / scale_y)
        return center2corner(center)


def IoU(rect1, rect2):
    # compute the intersection over union of two boxes
    x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
    tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
    # coordinates of the intersection rectangle
    xx1 = np.maximum(tx1, x1)
    yy1 = np.maximum(ty1, y1)
    xx2 = np.minimum(tx2, x2)
    yy2 = np.minimum(ty2, y2)
    # width and height of the intersection
    ww = np.maximum(0, xx2 - xx1)
    hh = np.maximum(0, yy2 - yy1)
    # area of rect1
    area = (x2-x1) * (y2-y1)
    # area of rect2
    target_a = (tx2-tx1) * (ty2 - ty1)
    # intersection
    inter = ww * hh
    # intersection over union
    overlap = inter / (area + target_a - inter)
    return overlap
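
A worked example: two 10x10 boxes overlapping in a 5x5 square give intersection 25 and union 100 + 100 - 25 = 175, so the IoU is 25 / 175 ≈ 0.143:

print(IoU([0, 0, 10, 10], [5, 5, 15, 15]))  # ~0.1429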

benchmark_helper.py

# --------------------------------------------------------
# dataset loading helpers
# --------------------------------------------------------
from os.path import join, realpath, dirname, exists, isdir
from os import listdir
import logging
import glob
import numpy as np
import json
from collections import OrderedDict


def get_dataset_zoo():
    """
    Detect which datasets are available under the data directory
    :return:
    """
    root = realpath(join(dirname(__file__), '../data'))
    zoos = listdir(root)
    # validity check for a dataset directory
    def valid(x):
        y = join(root, x)
        if not isdir(y): return False

        return exists(join(y, 'list.txt')) \
               or exists(join(y, 'train', 'meta.json')) \
               or exists(join(y, 'ImageSets', '2016', 'val.txt')) \
               or exists(join(y, 'ImageSets', '2017', 'test-dev.txt'))
    # keep only the valid dataset directories
    zoos = list(filter(valid, zoos))
    return zoos


dataset_zoo = get_dataset_zoo()

# load a dataset
def load_dataset(dataset):
    # OrderedDict keeps the insertion order of the videos
    info = OrderedDict()
    # VOT datasets
    if 'VOT' in dataset:
        # base path
        base_path = join(realpath(dirname(__file__)), '../data', dataset)
        if not exists(base_path):
            logging.error("Please download test dataset!!!")
            exit()
        # list.txt names the video sequences
        list_path = join(base_path, 'list.txt')
        # read the video list
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        # iterate over the videos
        for video in videos:
            # video path
            video_path = join(base_path, video)
            image_path = join(video_path, '*.jpg')
            image_files = sorted(glob.glob(image_path))
            if len(image_files) == 0:  # VOT2018
                image_path = join(video_path, 'color', '*.jpg')
                image_files = sorted(glob.glob(image_path))
            gt_path = join(video_path, 'groundtruth.txt')
            gt = np.loadtxt(gt_path, delimiter=',').astype(np.float64)
            if gt.shape[1] == 4:
                gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1,
                                      gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1]))
            info[video] = {'image_files': image_files, 'gt': gt, 'name': video}
    # DAVIS
    elif 'DAVIS' in dataset and 'TEST' not in dataset:
        base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS')
        list_path = join(realpath(dirname(__file__)), '../data', 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            info[video] = {}
            info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
            info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
            info[video]['name'] = video
    # ytb_vos (YouTube-VOS) dataset
    elif 'ytb_vos' in dataset:
        base_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid')
        json_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid', 'meta.json')
        meta = json.load(open(json_path, 'r'))
        meta = meta['videos']
        info = dict()
        for v in meta.keys():
            objects = meta[v]['objects']
            frames = []
            anno_frames = []
            info[v] = dict()
            for obj in objects:
                frames += objects[obj]['frames']
                anno_frames += [objects[obj]['frames'][0]]
            frames = sorted(np.unique(frames))
            info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames]
            info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames]
            info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames]
            info[v]['name'] = v

            info[v]['start_frame'] = dict()
            info[v]['end_frame'] = dict()
            for obj in objects:
                start_file = objects[obj]['frames'][0]
                end_file = objects[obj]['frames'][-1]
                info[v]['start_frame'][obj] = frames.index(start_file)
                info[v]['end_frame'][obj] = frames.index(end_file)
    # DAVIS test-dev data
    elif 'TEST' in dataset:
        base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS2017TEST')
        list_path = join(realpath(dirname(__file__)), '../data', 'DAVIS2017TEST', 'ImageSets', '2017', 'test-dev.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            info[video] = {}
            info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
            info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
            info[video]['name'] = video
    else:
        logging.error('Not support')
        exit()
    return info

config_helper.py

# --------------------------------------------------------
# config-file helpers
# --------------------------------------------------------
import json
from os.path import exists


def proccess_loss(cfg):
    """
    Fill in defaults for the loss section of the config
    :param cfg:
    :return:
    """
    # regression loss
    if 'reg' not in cfg:
        # default to L1Loss
        cfg['reg'] = {'loss': 'L1Loss'}
    else:
        if 'loss' not in cfg['reg']:
            cfg['reg']['loss'] = 'L1Loss'
    # classification loss
    if 'cls' not in cfg:
        cfg['cls'] = {'split': True}
    # weights of the cls, reg and mask losses
    cfg['weight'] = cfg.get('weight', [1, 1, 36])


def add_default(conf, default):
    # fill conf with default values
    default.update(conf)
    return default


def load_config(args):
    """
    Load the JSON configuration file given on the command line
    :param args: parsed command-line arguments
    :return: configuration dict
    """
    # the config file must exist
    assert exists(args.config), '"{}" not exists'.format(args.config)
    config = json.load(open(args.config))

    # deal with the network architecture
    if 'network' not in config:
        print('Warning: network missing in config. This will be an error in the next version')

        config['network'] = {}

        if not args.arch:
            raise Exception('no arch provided')
        config['network']['arch'] = args.arch
    args.arch = config['network']['arch']

    # deal with the loss section
    if 'loss' not in config:
        config['loss'] = {}

    proccess_loss(config['loss'])

    # deal with the learning-rate section
    if 'lr' not in config:
        config['lr'] = {}
    default = {
            'feature_lr_mult': 1.0,
            'rpn_lr_mult': 1.0,
            'mask_lr_mult': 1.0,
            'type': 'log',
            'start_lr': 0.03
            }
    default.update(config['lr'])
    config['lr'] = default

    # gradient clipping: merge the command-line value with the config
    if 'clip' in config or 'clip' in args.__dict__:
        if 'clip' not in config:
            config['clip'] = {}
        config['clip'] = add_default(config['clip'],
                {'feature': args.clip, 'rpn': args.clip, 'split': False})
        if config['clip']['feature'] != config['clip']['rpn']:
            config['clip']['split'] = True
        if not config['clip']['split']:
            args.clip = config['clip']['feature']

    return config
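
A hedged usage sketch for load_config: it only needs an object with .config, .clip and .arch attributes, so a Namespace works as a stand-in for the parsed command line. The printed values assume the config.json shown earlier, which has a 'network' section but no 'lr' or 'loss' section:

import argparse

args = argparse.Namespace(config='config.json', clip=10.0, arch='')
cfg = load_config(args)
print(args.arch)              # 'Custom', taken from config['network']['arch']
print(cfg['lr']['start_lr'])  # 0.03, the default filled in above
print(cfg['loss']['weight'])  # [1, 1, 36], the default loss weights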

load_helper.py

# --------------------------------------------------------
# model loading helpers
# --------------------------------------------------------

import torch
import logging
logger = logging.getLogger('global')


def check_keys(model, pretrained_state_dict):
    "check how the checkpoint keys match the model keys"
    # keys in the pretrained checkpoint
    ckpt_keys = set(pretrained_state_dict.keys())
    # keys in the model
    model_keys = set(model.state_dict().keys())
    # keys present in both
    used_pretrained_keys = model_keys & ckpt_keys
    # keys only in the checkpoint
    unused_pretrained_keys = ckpt_keys - model_keys
    # keys only in the model
    missing_keys = model_keys - ckpt_keys
    # warn about keys the checkpoint does not provide
    if len(missing_keys) > 0:
        logger.info('[Warning] missing keys: {}'.format(missing_keys))
        logger.info('missing keys:{}'.format(len(missing_keys)))
    # warn about checkpoint keys the model does not use
    if len(unused_pretrained_keys) > 0:
        logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys))
        logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    logger.info('used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True


def remove_prefix(state_dict, prefix):
    ''' Old-style models are stored with all parameter names sharing the common prefix 'module.' '''
    logger.info('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
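
A quick check of remove_prefix: only a leading 'module.' (as added by DataParallel) is stripped, and keys without the prefix are left untouched:

sd = {'module.features.conv1.weight': 0, 'rpn_model.bias': 1}
print(remove_prefix(sd, 'module.'))
# {'features.conv1.weight': 0, 'rpn_model.bias': 1}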


# load pretrained weights into the model
def load_pretrain(model, pretrained_path):
    logger.info('load pretrained model from {}'.format(pretrained_path))
    # load the checkpoint onto the right device
    if not torch.cuda.is_available():
        # CPU
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        # GPU
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    # strip the DataParallel 'module.' prefix
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')

    try:
        # check key compatibility
        check_keys(model, pretrained_dict)
    except:
        logger.info('[Warning]: using pretrain as features. Adding "features." as prefix')
        new_dict = {}
        for k, v in pretrained_dict.items():
            k = 'features.' + k
            new_dict[k] = v
        pretrained_dict = new_dict
        check_keys(model, pretrained_dict)
    # load the weights
    model.load_state_dict(pretrained_dict, strict=False)
    return model


def restore_from(model, optimizer, ckpt_path):
    "restore the model and optimizer from a checkpoint"
    logger.info('restore from {}'.format(ckpt_path))
    device = torch.cuda.current_device()
    # load the checkpoint
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage.cuda(device))
    epoch = ckpt['epoch']
    best_acc = ckpt['best_acc']
    arch = ckpt['arch']
    ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.')
    check_keys(model, ckpt_model_dict)
    # load the model weights
    model.load_state_dict(ckpt_model_dict, strict=False)
    # check the optimizer state
    check_keys(optimizer, ckpt['optimizer'])
    # load the optimizer state
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, epoch, best_acc, arch

log_helper.py

# --------------------------------------------------------
# logging helpers
# --------------------------------------------------------
from __future__ import division

import os
import logging
import sys
import math

# determine the source file for log records; handles exe, pyc, pyo and py files
if hasattr(sys, 'frozen'):  # support for py2exe
    _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:])
elif __file__[-4:].lower() in ['.pyc', '.pyo']:
    _srcfile = __file__[:-4] + '.py'
else:
    _srcfile = __file__
_srcfile = os.path.normcase(_srcfile)

# set of already-initialized loggers
logs = set()


class Filter:
    def __init__(self, flag):
        self.flag = flag

    def filter(self, x): return self.flag


class Dummy:
    def __init__(self, *arg, **kwargs):
        pass

    def __getattr__(self, arg):
        def dummy(*args, **kwargs): pass
        return dummy

# log formatting for multi-GPU (SLURM) runs: only rank 0 emits INFO records
def get_format(logger, level):
    if 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])

        if level == logging.INFO:
            logger.addFilter(Filter(rank == 0))
    else:
        rank = 0
    format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank)
    formatter = logging.Formatter(format_str)
    return formatter


def get_format_custom(logger, level):
    if 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        if level == logging.INFO:
            logger.addFilter(Filter(rank == 0))
    else:
        rank = 0
    format_str = '[%(asctime)s-rk{}-%(message)s'.format(rank)
    formatter = logging.Formatter(format_str)
    return formatter


def init_log(name, level = logging.INFO, format_func=get_format):
    if (name, level) in logs: return
    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    formatter = format_func(logger, level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger


def add_file_handler(name, log_file, level = logging.INFO):
    # get the named logger
    logger = logging.getLogger(name)
    # also write log records to a file
    fh = logging.FileHandler(log_file)
    # set the log format
    fh.setFormatter(get_format(logger, level))
    # attach the handler
    logger.addHandler(fh)


init_log('global')


def print_speed(i, i_time, n):
    """print_speed(index, index_time, total_iteration)"""
    logger = logging.getLogger('global')
    average_time = i_time
    remaining_time = (n - i) * average_time
    remaining_day = math.floor(remaining_time / 86400)
    remaining_hour = math.floor(remaining_time / 3600 - remaining_day * 24)
    remaining_min = math.floor(remaining_time / 60 - remaining_day * 1440 - remaining_hour * 60)
    logger.info('Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' % (i, n, i/n*100, average_time, remaining_day, remaining_hour, remaining_min))
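
For example, print_speed(100, 0.5, 1000) reports 10% progress at 0.5 s/iter with a remaining time of (1000 - 100) * 0.5 = 450 s, i.e. an ETA of 0:00:07 in the D:H:M format used above.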



def find_caller():
    """
    Use the stack frames to find the file, function and line number of the log call
    :return:
    """
    def current_frame():
        """
        Return the frame object for the caller's stack frame
        :return:
        """
        try:
            raise Exception
        except:
            return sys.exc_info()[2].tb_frame.f_back

    f = current_frame()
    if f is not None:
        f = f.f_back
    rv = "(unknown file)", 0, "(unknown function)"
    while hasattr(f, "f_code"):
        co = f.f_code
        filename = os.path.normcase(co.co_filename)
        rv = (co.co_filename, f.f_lineno, co.co_name)
        if filename == _srcfile:
            f = f.f_back
            continue
        break
    rv = list(rv)
    rv[0] = os.path.basename(rv[0])
    return rv

# log each unique message only once during training
class LogOnce:
    def __init__(self):
        self.logged = set()
        self.logger = init_log('log_once', format_func=get_format_custom)

    def log(self, strings):
        fn, lineno, caller = find_caller()
        key = (fn, lineno, caller, strings)
        if key in self.logged:
            return
        self.logged.add(key)
        message = "{filename:s}<{caller}>#{lineno:3d}] {strings}".format(filename=fn, lineno=lineno, strings=strings, caller=caller)
        self.logger.info(message)


once_logger = LogOnce()


def log_once(strings):
    once_logger.log(strings)

lr_helper.py

# --------------------------------------------------------
# learning-rate schedulers
# --------------------------------------------------------
from __future__ import division
import numpy as np
import math
from torch.optim.lr_scheduler import _LRScheduler
import matplotlib.pyplot as plt


class LRScheduler(_LRScheduler):
    """
    Base class for the learning-rate schedules
    """
    def __init__(self, optimizer, last_epoch=-1):
        # lr_spaces (the per-epoch lr sequence) must be set by the subclass
        if 'lr_spaces' not in self.__dict__:
            raise Exception('lr_spaces must be set in "LRScheduler"')
        super(LRScheduler, self).__init__(optimizer, last_epoch)

    def get_cur_lr(self):
        """
        Learning rate of the current epoch
        :return:
        """
        return self.lr_spaces[self.last_epoch]

    def get_lr(self):
        """
        Learning rates for all parameter groups at the current epoch
        :return:
        """
        epoch = self.last_epoch
        # scale each group's initial lr by the schedule
        return [self.lr_spaces[epoch] * pg['initial_lr'] / self.start_lr for pg in self.optimizer.param_groups]

    def __repr__(self):
        """
        Show the full lr sequence
        :return:
        """
        return "({}) lr spaces: \n{}".format(self.__class__.__name__, self.lr_spaces)


class LogScheduler(LRScheduler):
    """
    Exponential (log-space) decay of the learning rate
    """
    def __init__(self, optimizer, start_lr=0.03, end_lr=5e-4, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.epochs = epochs
        # interpolate geometrically between start_lr and end_lr over the epochs
        self.lr_spaces = np.logspace(math.log10(start_lr), math.log10(end_lr), epochs)

        super(LogScheduler, self).__init__(optimizer, last_epoch)
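
The log schedule is just a geometric interpolation between start_lr and end_lr, which np.logspace produces directly; with the default endpoints shrunk to 5 epochs:

import math
import numpy as np

lrs = np.logspace(math.log10(0.03), math.log10(5e-4), 5)
print(lrs)  # ~[0.03, 0.0108, 0.0039, 0.0014, 0.0005], a constant ratio per epoch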


class StepScheduler(LRScheduler):
    """
    Step decay of the learning rate
    """
    def __init__(self, optimizer, start_lr=0.01, end_lr=None, step=10, mult=0.1, epochs=50, last_epoch=-1, **kwargs):
        """
        Initialization
        :param optimizer: optimizer
        :param start_lr: initial lr
        :param end_lr: final lr
        :param step: decay the learning rate every `step` epochs
        :param mult: decay factor gamma
        :param epochs:
        :param last_epoch: starting epoch
        :param kwargs:
        """

        # if end_lr is given, derive the missing parameter from it
        if end_lr is not None:
            if start_lr is None:
                # derive start_lr from end_lr and mult
                start_lr = end_lr / (mult ** (epochs // step))
            else:  # for warm up policy
                # derive mult from start_lr and end_lr
                mult = math.pow(end_lr/start_lr, 1. / (epochs // step))
        self.start_lr = start_lr
        # the full lr sequence
        self.lr_spaces = self.start_lr * (mult**(np.arange(epochs) // step))
        self.mult = mult
        # the learning rate changes every `step` epochs
        self._step = step

        super(StepScheduler, self).__init__(optimizer, last_epoch)


class MultiStepScheduler(LRScheduler):
    """
    Multi-step decay of the learning rate
    """
    def __init__(self, optimizer, start_lr=0.01, end_lr=None, steps=[10,20,30,40], mult=0.5, epochs=50, last_epoch=-1, **kwargs):
        """
        :param optimizer: optimizer
        :param start_lr: initial lr
        :param end_lr: final lr
        :param steps: epochs at which the learning rate is decayed
        :param mult: decay factor
        :param epochs:
        :param last_epoch: starting epoch
        :param kwargs:
        """
        if end_lr is not None:
            if start_lr is None:
                # derive start_lr
                start_lr = end_lr / (mult ** (len(steps)))
            else:
                # derive mult
                mult = math.pow(end_lr/start_lr, 1. / len(steps))
        self.start_lr = start_lr
        # build lr_spaces
        self.lr_spaces = self._build_lr(start_lr, steps, mult, epochs)
        self.mult = mult
        self.steps = steps

        super(MultiStepScheduler, self).__init__(optimizer, last_epoch)

    def _build_lr(self, start_lr, steps, mult, epochs):
        """
        Build the lr sequence
        :param start_lr:
        :param steps:
        :param mult:
        :param epochs:
        :return:
        """
        lr = [0] * epochs
        lr[0] = start_lr
        for i in range(1, epochs):
            lr[i] = lr[i-1]
            # decay the learning rate at the listed epochs, otherwise keep it
            if i in steps:
                lr[i] *= mult
        return np.array(lr, dtype=np.float32)


class LinearStepScheduler(LRScheduler):
    """
    Linear decay of the learning rate
    """
    def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr = start_lr
        self.end_lr = end_lr
        # linearly interpolate between start_lr and end_lr
        self.lr_spaces = np.linspace(start_lr, end_lr, epochs)

        super(LinearStepScheduler, self).__init__(optimizer, last_epoch)


class CosStepScheduler(LRScheduler):
    """
    Cosine decay of the learning rate
    """
    def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr = start_lr
        self.end_lr = end_lr
        # build the lr sequence
        self.lr_spaces = self._build_lr(start_lr, end_lr, epochs)

        super(CosStepScheduler, self).__init__(optimizer, last_epoch)

    def _build_lr(self, start_lr, end_lr, epochs):
        """
        Build the lr sequence
        :param start_lr: initial learning rate
        :param end_lr: final learning rate
        :param epochs: number of epochs
        :return:
        """
        # epoch indices as floats
        index = np.arange(epochs).astype(np.float32)
        # cosine interpolation from start_lr down to end_lr
        lr = end_lr + (start_lr - end_lr) * (1. + np.cos(index * np.pi/ epochs)) * 0.5
        return lr.astype(np.float32)


class WarmUPScheduler(LRScheduler):
    """
    Concatenate a warm-up schedule with a normal schedule
    """
    def __init__(self, optimizer, warmup, normal, epochs=50, last_epoch=-1):
        warmup = warmup.lr_spaces # [::-1]
        normal = normal.lr_spaces
        # concatenate the two lr sequences
        self.lr_spaces = np.concatenate([warmup, normal])
        self.start_lr = normal[0]

        super(WarmUPScheduler, self).__init__(optimizer, last_epoch)
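
So a warm-up is nothing more than two lr sequences concatenated; a 5-epoch linear ramp followed by a 15-epoch log decay, for example, yields one 20-entry schedule. A standalone sketch of what build_lr_scheduler assembles below when the config has a 'warmup' section:

import math
import numpy as np

warmup = np.linspace(0.001, 0.005, 5)                          # rising phase
normal = np.logspace(math.log10(0.005), math.log10(5e-4), 15)  # decaying phase
lr_spaces = np.concatenate([warmup, normal])
print(len(lr_spaces), lr_spaces[0], lr_spaces[-1])             # 20 epochs, 0.001 ... 0.0005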

# registry of the available schedulers
LRs = {
    'log': LogScheduler,
    'step': StepScheduler,
    'multi-step': MultiStepScheduler,
    'linear': LinearStepScheduler,
    'cos': CosStepScheduler}


def _build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    Build a scheduler from the configuration
    :param optimizer:
    :param cfg:
    :param epochs:
    :param last_epoch:
    :return:
    """
    # default to the log schedule
    if 'type' not in cfg:
        cfg['type'] = 'log'
    # unknown schedule types raise an exception
    if cfg['type'] not in LRs:
        raise Exception('Unknown type of LR Scheduler "%s"'%cfg['type'])
    # instantiate the scheduler
    return LRs[cfg['type']](optimizer, last_epoch=last_epoch, epochs=epochs, **cfg)


def _build_warm_up_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    根据配置信息,按照warm_up方式完成学习率更新
    :param optimizer:
    :param cfg:
    :param epochs:
    :param last_epoch:
    :return:
    """
    # 获取第一种更新方式的epoch
    warmup_epoch = cfg['warmup']['epoch']
    # 构建学习率更新列表
    # 将学习率增加
    sc1 = _build_lr_scheduler(optimizer, cfg['warmup'], warmup_epoch, last_epoch)
    # 学习率下降
    sc2 = _build_lr_scheduler(optimizer, cfg, epochs - warmup_epoch, last_epoch)
    # 返回连接后的结果
    return WarmUPScheduler(optimizer, sc1, sc2, epochs, last_epoch)


def build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    将上述两种方法进行整合
    :param optimizer:
    :param cfg:
    :param epochs:
    :param last_epoch:
    :return:
    """
    # 若配置信息中含有"warmup"
    if 'warmup' in cfg:
        return _build_warm_up_scheduler(optimizer, cfg, epochs, last_epoch)
    else:
    # 否则
        return _build_lr_scheduler(optimizer, cfg, epochs, last_epoch)
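The only switch between the two code paths is whether the config contains a 'warmup' key. A sketch of the two kinds of config (the key names mirror the demo below; the values are hypothetical):

# a plain config dispatches to _build_lr_scheduler
plain_cfg = {'type': 'cos', 'start_lr': 0.01, 'end_lr': 0.0005}
# a config with a 'warmup' entry (which must carry an 'epoch' count) dispatches to _build_warm_up_scheduler
warmup_cfg = {'type': 'log', 'start_lr': 0.03, 'end_lr': 5e-4,
              'warmup': {'type': 'step', 'start_lr': 0.001, 'end_lr': 0.03, 'step': 1, 'epoch': 5}}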


if __name__ == '__main__':

    import torch.nn as nn
    from torch.optim import SGD
    import matplotlib.pyplot as plt
    # build a toy model
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(10, 10, kernel_size=3)
    # instantiate the model and pass its parameters to the optimizer
    net = Net()
    optimizer = SGD(net.parameters(), lr=0.01)

    # step schedule
    step = {
            'type': 'step',
            'start_lr': 0.01,
            'step': 10,
            'mult': 0.1
            }
    lr = build_lr_scheduler(optimizer, step)
    print("test1")
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("StepScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # linear schedule
    step = {
        'type': 'linear',
        'start_lr': 0.01,
        'end_lr':0.005
    }
    lr = build_lr_scheduler(optimizer, step)
    print("test1")
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("linear")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # log schedule
    log = {
            'type': 'log',
            'start_lr': 0.03,
            'end_lr': 5e-4,
            }
    lr = build_lr_scheduler(optimizer, log)

    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("logScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # multi-step schedule
    log = {
            'type': 'multi-step',
            "start_lr": 0.01,
            "mult": 0.1,
            "steps": [10, 15, 20]
            }
    lr = build_lr_scheduler(optimizer, log)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("MultiStepScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # cosine schedule
    cos = {
            "type": 'cos',
            'start_lr': 0.01,
            'end_lr': 0.0005,
            }
    lr = build_lr_scheduler(optimizer, cos)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("CosScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # warm-up schedule: the learning rate first rises, then decays
    step = {
            'type': 'step',
            'start_lr': 0.001,
            'end_lr': 0.03,
            'step': 1,
            }

    # note: `log` was reassigned above, so the normal phase here reuses the multi-step config
    warmup = log.copy()
    warmup['warmup'] = step
    warmup['warmup']['epoch'] = 10
    lr = build_lr_scheduler(optimizer, warmup, epochs=55)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("WarmupScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()

    # advance the scheduler by one epoch
    lr.step()
    print(lr.last_epoch)

    # jump directly to epoch 5
    lr.step(5)
    print(lr.last_epoch)


tracker_config.py

# --------------------------------------------------------
# Tracker parameter settings
# --------------------------------------------------------
from __future__ import division
from utils.anchors import Anchors


# tracker configuration parameters
class TrackerConfig(object):
    penalty_k = 0.09
    window_influence = 0.39
    lr = 0.38
    seg_thr = 0.3  # segmentation threshold for the mask
    windowing = 'cosine'  # window used to penalize large displacements [cosine/uniform]
    # Params from the network architecture, have to be consistent with the training
    exemplar_size = 127  # size of the target template, input z size
    instance_size = 255  # size of the tracked instance, input x size (search region)
    total_stride = 8  # total stride of the network
    out_size = 63  # for mask
    base_size = 8
    score_size = (instance_size-exemplar_size)//total_stride+1+base_size
    context_amount = 0.5  # context amount for the exemplar
    ratios = [0.33, 0.5, 1, 2, 3]  # anchor aspect ratios
    scales = [8, ]  # anchor scales, i.e. anchor sizes
    anchor_num = len(ratios) * len(scales)  # number of anchors per position
    round_dight = 0  #
    anchor = []

    def update(self, newparam=None, anchors=None):
        """
        更新参数
        :param newparam: 新的参数
        :param anchors: anchors的参数
        :return:
        """
        # 新的参数直接添加到配置中
        if newparam:
            for key, value in newparam.items():
                setattr(self, key, value)
        # 添加anchors的参数
        if anchors is not None:
            # 若anchors是字典形式的将其转换为Anchors
            if isinstance(anchors, dict):
                anchors = Anchors(anchors)
            # 更新到config中
            if isinstance(anchors, Anchors):
                self.total_stride = anchors.stride
                self.ratios = anchors.ratios
                self.scales = anchors.scales
                self.round_dight = anchors.round_dight
        self.renew()

    def renew(self):
        """
        更新配置信息
        :return:
        """
        # 分类尺寸
        self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size
        # anchor数目
        self.anchor_num = len(self.ratios) * len(self.scales)
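With the class defaults above, the derived values in renew() work out as follows; a short check of the arithmetic, repeated outside the class:

# worked check of the derived values, using the class defaults
instance_size, exemplar_size, total_stride, base_size = 255, 127, 8, 8
print((instance_size - exemplar_size) // total_stride + 1 + base_size)  # (128 // 8) + 1 + 8 = 25

ratios, scales = [0.33, 0.5, 1, 2, 3], [8]
print(len(ratios) * len(scales))  # 5 anchors per position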



