日萌社
人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)
CNN:RCNN、SPPNet、Fast RCNN、Faster RCNN、YOLO V1 V2 V3、SSD、FCN、SegNet、U-Net、DeepLab V1 V2 V3、Mask RCNN
单目标跟踪 Siamese系列网络:SiamFC、SiamRPN、one-shot跟踪、one-shotting单样本学习、DaSiamRPN、SiamRPN++、SiamMask
SiamMask_master\tools
config.json
{
"network": {
"arch": "Custom"
},
"hp": {
"instance_size": 255,
"base_size": 8,
"out_size": 127,
"seg_thr": 0.35,
"penalty_k": 0.04,
"window_influence": 0.4,
"lr": 1.0
},
"anchors": {
"stride": 8,
"ratios": [0.33, 0.5, 1, 2, 3],
"scales": [8],
"round_dight": 0
}
}
demo.py
# --------------------------------------------------------
# demo.py
# --------------------------------------------------------
import io
import json
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from tools.test import *
import glob
# Command-line interface of the demo.
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
# Checkpoint whose weights are loaded into the model.
parser.add_argument('--resume', default='SiamMask.pth', type=str,
                    metavar='PATH', help='path to latest checkpoint (default: SiamMask.pth)')
# JSON file providing the "hp" (tracker hyper-parameters) and "anchors" sections.
parser.add_argument('--config', dest='config', default='config.json',
                    help='hyper-parameter of SiamMask in json format')
# Directory holding the image sequence (*.jp* files) to track through.
parser.add_argument('--base_path', default='../data/car', help='datasets')
# Request CPU mode.
parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()
# Not used by this script.
writer = None
if __name__ == '__main__':
    # 1. Device: CUDA when available, unless --cpu was given.
    device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
    # Input sizes stay fixed while tracking, so let cuDNN autotune its conv algorithms.
    torch.backends.cudnn.benchmark = True

    # 2. Model: build the Custom network with the configured anchors and load the weights.
    cfg = load_config(args)
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)
    # eval() switches dropout / batch-norm layers to inference behaviour.
    siammask.eval().to(device)

    # 3. Image sequence: every *.jp* file in base_path, in name order (unreadable files skipped).
    img_files = sorted(glob.glob(join(args.base_path, '*.jp*')))
    ims = [im for im in (cv2.imread(imf) for imf in img_files) if im is not None]
    if not ims:
        raise SystemExit('No readable *.jp* images found in {}'.format(args.base_path))

    # 4. Let the user draw the initial target box (x, y, w, h) on the first frame.
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    try:
        x, y, w, h = cv2.selectROI('SiamMask', ims[0], False, False)
    except cv2.error:
        raise SystemExit('Could not select a target region.')
    print(x, y, w, h)
    # selectROI returns an empty box when the selection is cancelled.
    if w == 0 or h == 0:
        raise SystemExit('No target selected.')

    # 5. Initialise on frame 0, then track and display every following frame.
    toc = 0
    n_frames = 0
    for f, im in enumerate(ims):
        tic = cv2.getTickCount()
        if f == 0:  # init
            # Box centre and size.
            target_pos = np.array([x + w / 2, y + h / 2])
            target_sz = np.array([w, h])
            state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'], device=device)
        else:  # tracking
            state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device)
            # Rotated box as 4 corner points, flattened.
            location = state['ploygon'].flatten()
            mask = state['mask'] > state['p'].seg_thr

            # Paint the mask into the red channel (BGR image) and draw the box.
            im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
            # np.int0 was removed in NumPy 2.0; polylines expects int32 points.
            cv2.polylines(im, [location.astype(np.int32).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
            cv2.imshow('SiamMask', im)
            key = cv2.waitKey(1)
            if key > 0:
                break

        toc += cv2.getTickCount() - tic
        n_frames += 1
    toc /= cv2.getTickFrequency()
    # Frames per second over the frames that were actually timed.
    fps = n_frames / toc if toc > 0 else 0.0
    print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visulization!)'.format(toc, fps))
resnet.py
import torch.nn as nn
import torch
from torch.autograd import Variable
import math
import torch.utils.model_zoo as model_zoo
from models.features import Features
# Public API of this module.
__all__ = [
    'ResNet',
    'resnet18', 'resnet34', 'resnet50', 'resnet101',
    'resnet152',
]

# ImageNet-pretrained checkpoints, keyed by the name of the factory function that loads them.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the ``nn.Conv2d`` module
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Residual block of two stacked 3x3 convolutions (ResNet-18/34 style).

    Accepts ``dilation`` because ``ResNet._make_layer`` always passes it.
    The padding of the first convolution follows the same rule as
    ``Bottleneck``: ``dilation`` when dilated, otherwise ``2 - stride``.
    A stride-2 block therefore shrinks the map exactly like the 3x3, pad-0
    shortcut that ``_make_layer`` builds.
    """
    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        padding = dilation if dilation > 1 else 2 - stride
        # 3x3 conv (carries stride / dilation) -> BN -> ReLU
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=padding, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Size-preserving 3x3 conv -> BN
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional shortcut projection, used when the main path changes shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the input so it matches the main path's channels / size.
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
class Bottleneck(Features):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50/101/152 style).

    The first 1x1 conv reduces the depth to *planes*, the 3x3 conv carries the
    stride or dilation, and the last 1x1 conv expands to ``planes * expansion``.
    Inherits from ``Features`` (project base class for feature extractors).
    """
    # Output depth multiplier (e.g. planes=64 -> 256 output channels).
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        # 1x1 conv: reduce depth
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Padding 0 for stride 2 (shrinks the map), 1 for stride 1 (keeps size),
        # and == dilation for dilated convs (keeps size).
        padding = 2 - stride
        assert stride == 1 or dilation == 1, "stride and dilation cannot both be greater than one"
        if dilation > 1:
            padding = dilation
        # 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv: expand depth
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional shortcut projection.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if out.size() != residual.size():
            # Fail with both shapes instead of printing them and crashing in the add.
            raise RuntimeError('Bottleneck: main path {} and shortcut {} have different shapes'
                               .format(tuple(out.size()), tuple(residual.size())))

        out += residual
        out = self.relu(out)
        return out
class Bottleneck_nop(nn.Module):
    """Bottleneck block whose 3x3 conv uses no padding.

    The 3x3 conv trims one pixel from every border (for stride 1). The
    shortcut is cropped by the same amount so both paths can be added.
    """
    # Output depth multiplier.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_nop, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # Drop the one-pixel border the unpadded 3x3 conv removed. Height and width are
        # cropped independently (the old code used the width for both axes).
        residual = residual[:, :, 1:-1, 1:-1]

        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet trunk used by SiamMask: stem, layer1, layer2 and optional layer3/layer4.

    :param block: residual block class (``BasicBlock`` or ``Bottleneck``)
    :param layers: number of blocks in each of layer1..layer4
    :param layer4: build layer4 (stride 1, dilation 4); otherwise it is an identity
    :param layer3: build layer3 (stride 1, dilation 2); otherwise it is an identity
    """

    def __init__(self, block, layers, layer4=False, layer3=False):
        super(ResNet, self).__init__()
        # Channel depth entering the next stage; updated by _make_layer.
        self.inplanes = 64

        # Stem: 7x7/2 conv (no padding) -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 31x31, 15x15
        self.feature_size = 128 * block.expansion

        if not layer3:
            self.layer3 = lambda x: x  # identity
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)  # 15x15, 7x7
            self.feature_size = (256 + 128) * block.expansion

        if not layer4:
            self.layer4 = lambda x: x  # identity
        else:
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)  # 7x7, 3x3
            self.feature_size = 512 * block.expansion

        # He-style init for convs (std = sqrt(2 / fan_out)); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Stack *blocks* residual blocks of width *planes* into one stage.

        The first block may change stride/depth and then receives a shortcut
        projection. A 1x1 conv is used when the spatial size is unchanged;
        otherwise a 3x3 conv is used (no padding for stride, padding ==
        dilation // 2 for dilation). The first block runs at half the stage's
        dilation in the dilated case.

        :param block: residual block class
        :param planes: block width (output depth is planes * block.expansion)
        :param blocks: number of blocks in the stage
        :param stride: stride of the first block
        :param dilation: dilation of the stage
        :return: ``nn.Sequential`` of the blocks
        """
        out_planes = planes * block.expansion
        first_dilation = dilation
        downsample = None

        if stride != 1 or self.inplanes != out_planes:
            if stride == 1 and dilation == 1:
                shortcut_conv = nn.Conv2d(self.inplanes, out_planes,
                                          kernel_size=1, stride=stride, bias=False)
            else:
                first_dilation = dilation // 2 if dilation > 1 else 1
                shortcut_pad = first_dilation if dilation > 1 else 0
                shortcut_conv = nn.Conv2d(self.inplanes, out_planes,
                                          kernel_size=3, stride=stride, bias=False,
                                          padding=shortcut_pad, dilation=first_dilation)
            downsample = nn.Sequential(shortcut_conv, nn.BatchNorm2d(out_planes))

        stage = [block(self.inplanes, planes, stride, downsample, dilation=first_dilation)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the stem output and the layer1/layer2/layer3 feature maps (layer4 is not run)."""
        p0 = self.relu(self.bn1(self.conv1(x)))
        p1 = self.layer1(self.maxpool(p0))
        p2 = self.layer2(p1)
        p3 = self.layer3(p2)
        return p0, p1, p2, p3
class ResAdjust(nn.Module):
    """Per-level adjust heads: refine layer2/3/4 features and project them to *out_channels*.

    Not referenced elsewhere in this file.

    :param block: residual block class used for the refinement blocks
    :param out_channels: channel depth of every returned feature map
    :param adjust_number: residual blocks per level before the projection
    :param fuse_layers: iterable naming the levels (2, 3 and/or 4) to build heads for
    """

    def __init__(self,
                 block=Bottleneck,
                 out_channels=256,
                 adjust_number=1,
                 fuse_layers=(2, 3, 4)):
        super(ResAdjust, self).__init__()
        # Tuple default avoids a shared mutable list; any iterable is still accepted.
        self.fuse_layers = set(fuse_layers)

        # Dilations 1/2/4 mirror those of layer2/layer3/layer4 in ResNet.
        if 2 in self.fuse_layers:
            self.layer2 = self._make_layer(block, 128, 1, out_channels, adjust_number)
        if 3 in self.fuse_layers:
            self.layer3 = self._make_layer(block, 256, 2, out_channels, adjust_number)
        if 4 in self.fuse_layers:
            self.layer4 = self._make_layer(block, 512, 4, out_channels, adjust_number)

        self.feature_size = out_channels * len(self.fuse_layers)

    def _make_layer(self, block, planes, dilation, out, number=1):
        """Stack *number* residual blocks on ``planes * expansion`` channels, then a 3x3 conv + BN to *out*."""
        layers = [block(planes * block.expansion, planes, dilation=dilation) for _ in range(number)]
        projection = nn.Sequential(
            nn.Conv2d(planes * block.expansion, out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out),
        )
        layers.append(projection)
        return nn.Sequential(*layers)

    def forward(self, p2, p3, p4):
        """Return the adjusted features of the enabled levels, ordered 2, 3, 4."""
        outputs = []
        if 2 in self.fuse_layers:
            outputs.append(self.layer2(p2))
        if 3 in self.fuse_layers:
            outputs.append(self.layer3(p3))
        if 4 in self.fuse_layers:
            outputs.append(self.layer4(p4))
        return outputs
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 (``BasicBlock``, stage sizes [2, 2, 2, 2]).

    Args:
        pretrained (bool): if True, load the ImageNet weights listed in ``model_urls``.
            NOTE(review): this ResNet differs from torchvision's (3x3 downsample convs,
            no fc layer), so the strict load may fail — confirm.
        **kwargs: forwarded to ``ResNet`` (e.g. ``layer3``, ``layer4``).
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet18'])
    net.load_state_dict(weights)
    return net
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 (``BasicBlock``, stage sizes [3, 4, 6, 3]).

    Args:
        pretrained (bool): if True, load the ImageNet weights listed in ``model_urls``.
            NOTE(review): this ResNet differs from torchvision's (3x3 downsample convs,
            no fc layer), so the strict load may fail — confirm.
        **kwargs: forwarded to ``ResNet`` (e.g. ``layer3``, ``layer4``).
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet34'])
    net.load_state_dict(weights)
    return net
def resnet50(pretrained=False, **kwargs):
    """Construct the SiamMask ResNet-50 (``Bottleneck``, stage sizes [3, 4, 6, 3]).

    Args:
        pretrained (bool): if True, load the ImageNet weights listed in ``model_urls``.
            NOTE(review): this ResNet differs from torchvision's (3x3 downsample convs,
            no fc layer), so the strict load may fail — confirm.
        **kwargs: forwarded to ``ResNet`` (e.g. ``layer3``, ``layer4``).
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights)
    return net
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 (``Bottleneck``, stage sizes [3, 4, 23, 3]).

    Args:
        pretrained (bool): if True, load the ImageNet weights listed in ``model_urls``.
            NOTE(review): this ResNet differs from torchvision's (3x3 downsample convs,
            no fc layer), so the strict load may fail — confirm.
        **kwargs: forwarded to ``ResNet`` (e.g. ``layer3``, ``layer4``).
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(weights)
    return net
def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 (``Bottleneck``, stage sizes [3, 8, 36, 3]).

    Args:
        pretrained (bool): if True, load the ImageNet weights listed in ``model_urls``.
            NOTE(review): this ResNet differs from torchvision's (3x3 downsample convs,
            no fc layer), so the strict load may fail — confirm.
        **kwargs: forwarded to ``ResNet`` (e.g. ``layer3``, ``layer4``).
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet152'])
    net.load_state_dict(weights)
    return net
if __name__ == '__main__':
    # Smoke test: build ResNet-50 and run one exemplar-sized (127x127) input through it.
    net = resnet50()
    print(net)
    # Random input: torch.FloatTensor(...) would hand the net uninitialised memory (possibly NaN/inf).
    var = torch.randn(1, 3, 127, 127)
    print(var)
    with torch.no_grad():
        out = net(var)
    print(out)
test.py
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
from __future__ import division
import argparse
import logging
import numpy as np
import cv2
from PIL import Image
from os import makedirs
from os.path import join, isdir, isfile
from utils.log_helper import init_log, add_file_handler
from utils.load_helper import load_pretrain
from utils.bbox_helper import get_axis_aligned_bbox, cxy_wh_2_rect
from utils.benchmark_helper import load_dataset, dataset_zoo
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from utils.anchors import Anchors
from utils.tracker_config import TrackerConfig
from utils.config_helper import load_config
# from utils.pyvotkit.region import vot_overlap, vot_float2str
# Probability thresholds swept when binarising predicted masks for evaluation
# (np.arange excludes the stop value: 0.3, 0.35, 0.4, 0.45).
thrs = np.arange(0.3, 0.5, 0.05)

# ---- command-line options of the test script --------------------------------
parser = argparse.ArgumentParser(description='Test SiamMask')
# model and checkpoint
parser.add_argument('--arch', dest='arch', default='', choices=['Custom'],
                    help='architecture of pretrained model')
parser.add_argument('--config', dest='config', required=True,
                    help='hyper-parameter for SiamMask')
parser.add_argument('--resume', default='', type=str, required=True, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# tracker outputs
parser.add_argument('--mask', action='store_true',
                    help='whether use mask output')
parser.add_argument('--refine', action='store_true',
                    help='whether use mask refine output')
# evaluation data and logging
parser.add_argument('--dataset', dest='dataset', default='VOT2018', choices=dataset_zoo,
                    help='datasets')
parser.add_argument('-l', '--log', default="log_test.txt", type=str,
                    help='log file')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether visualize result')
# DAVIS-specific switches
parser.add_argument('--save_mask', action='store_true',
                    help='whether use save mask for davis')
parser.add_argument('--gt', action='store_true',
                    help='whether use gt rect for davis (Oracle)')
# misc
parser.add_argument('--video', default='', type=str,
                    help='test special video')
parser.add_argument('--cpu', action='store_true',
                    help='cpu mode')
parser.add_argument('--debug', action='store_true',
                    help='debug mode')
def to_torch(ndarray):
    """Return *ndarray* as a torch tensor.

    Objects whose type is defined in the ``numpy`` module are converted with
    ``torch.from_numpy`` (memory is shared). Tensors are returned unchanged.
    Anything else raises ``ValueError``.
    """
    if torch.is_tensor(ndarray):
        return ndarray
    if type(ndarray).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return torch.from_numpy(ndarray)
def im_to_torch(img):
    """Convert an H x W x C image array into a float32 C x H x W torch tensor.

    :param img: image array with channels last
    :return: float tensor with channels first
    """
    channels_first = np.transpose(img, (2, 0, 1))
    return to_torch(channels_first).float()
def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch'):
    """Crop a square window centred on *pos* and resize it to the network input size.

    Parts of the window that fall outside the image are filled with *avg_chans*.

    :param im: H x W x C image (the padded copy is built as uint8)
    :param pos: window centre (x, y); a single float is used for both coordinates
    :param model_sz: side length of the returned patch (network input size)
    :param original_sz: side length, in image pixels, of the window to crop
    :param avg_chans: per-channel fill value for out-of-image pixels
    :param out_mode: exactly 'torch' returns a float C x H x W tensor (via im_to_torch);
        any other value returns the H x W x C numpy patch
    :return: the cropped (and, if needed, resized) patch
    """
    if isinstance(pos, float):
        pos = [pos, pos]
    sz = original_sz
    im_h, im_w, im_c = im.shape
    # Distance from the window centre to its edge.
    half = (original_sz + 1) / 2
    context_xmin = round(pos[0] - half)
    context_xmax = context_xmin + sz - 1
    context_ymin = round(pos[1] - half)
    context_ymax = context_ymin + sz - 1
    # How far the window sticks out of the image on each side.
    left_pad = int(max(0., -context_xmin))
    top_pad = int(max(0., -context_ymin))
    right_pad = int(max(0., context_xmax - im_w + 1))
    bottom_pad = int(max(0., context_ymax - im_h + 1))

    # Padding moves the origin: express the window in padded-image coordinates.
    context_xmin = context_xmin + left_pad
    context_xmax = context_xmax + left_pad
    context_ymin = context_ymin + top_pad
    context_ymax = context_ymax + top_pad

    if any([top_pad, bottom_pad, left_pad, right_pad]):
        padded = np.zeros((im_h + top_pad + bottom_pad, im_w + left_pad + right_pad, im_c), np.uint8)
        padded[top_pad:top_pad + im_h, left_pad:left_pad + im_w, :] = im
        # Fill the border with the mean colour so it carries no strong edges.
        if top_pad:
            padded[0:top_pad, left_pad:left_pad + im_w, :] = avg_chans
        if bottom_pad:
            padded[im_h + top_pad:, left_pad:left_pad + im_w, :] = avg_chans
        if left_pad:
            padded[:, 0:left_pad, :] = avg_chans
        if right_pad:
            padded[:, im_w + left_pad:, :] = avg_chans
        source = padded
    else:
        source = im
    im_patch_original = source[int(context_ymin):int(context_ymax + 1),
                               int(context_xmin):int(context_xmax + 1), :]

    # Resize only when the crop differs from the network input size.
    if not np.array_equal(model_sz, original_sz):
        im_patch = cv2.resize(im_patch_original, (model_sz, model_sz))
    else:
        im_patch = im_patch_original

    # Exact match: the old `out_mode in 'torch'` was a substring test ('' or 't' also matched).
    if out_mode == 'torch':
        return im_to_torch(im_patch)
    return im_patch
def generate_anchor(cfg, score_size):
    """Lay the configured anchor shapes over every cell of the score map.

    :param cfg: anchor configuration passed to ``Anchors``
    :param score_size: side length of the (square) score map
    :return: array of shape (num_shapes * score_size**2, 4) with rows (cx, cy, w, h).
        Rows are grouped by anchor shape, then by grid cell in row-major order.
        Centres are offsets in steps of the anchor stride, with the centre cell at 0.
    """
    anchors = Anchors(cfg)
    corners = anchors.anchors
    # (x1, y1, x2, y2) -> (cx, cy, w, h)
    shapes = np.stack([(corners[:, 0] + corners[:, 2]) * 0.5,
                       (corners[:, 1] + corners[:, 3]) * 0.5,
                       corners[:, 2] - corners[:, 0],
                       corners[:, 3] - corners[:, 1]], 1)

    stride = anchors.stride
    num_shapes = shapes.shape[0]
    # One copy of every shape per grid cell.
    anchor = np.tile(shapes, score_size * score_size).reshape((-1, 4))

    # Grid coordinates centred on the middle cell.
    offset = - (score_size // 2) * stride
    xs = [offset + stride * i for i in range(score_size)]
    ys = [offset + stride * i for i in range(score_size)]
    xx, yy = np.meshgrid(xs, ys)
    centres_x = np.tile(xx.flatten(), (num_shapes, 1)).flatten()
    centres_y = np.tile(yy.flatten(), (num_shapes, 1)).flatten()

    anchor[:, 0] = centres_x.astype(np.float32)
    anchor[:, 1] = centres_y.astype(np.float32)
    return anchor
def siamese_init(im, target_pos, target_sz, model, hp=None, device='cpu'):
    """Initialise the tracker on the first frame and return its state dict.

    Builds the tracker configuration, generates the anchors and feeds the
    exemplar crop around the target to ``model.template``. It also precomputes
    the score-map penalty window.

    :param im: first frame (H x W x C image)
    :param target_pos: target centre (x, y)
    :param target_sz: target size (w, h)
    :param model: tracking network; must expose ``anchors``, ``anchor_num`` and ``template``
    :param hp: hyper-parameters merged into ``TrackerConfig``
    :param device: torch device for the exemplar tensor
    :return: dict with keys 'im_h', 'im_w', 'p', 'net', 'avg_chans', 'window',
        'target_pos', 'target_sz'
    :raises ValueError: if the configured windowing is neither 'cosine' nor 'uniform'
    """
    state = dict()
    state['im_h'] = im.shape[0]
    state['im_w'] = im.shape[1]

    # Tracker configuration: defaults, then hp / anchor overrides, then derived values.
    p = TrackerConfig()
    p.update(hp, model.anchors)
    p.renew()
    net = model
    p.scales = model.anchors['scales']
    p.ratios = model.anchors['ratios']
    p.anchor_num = model.anchor_num
    p.anchor = generate_anchor(model.anchors, p.score_size)

    # Score-map penalty window, one copy per anchor shape. Validated before the
    # network is touched so a bad config does not leave a half-initialised model.
    if p.windowing == 'cosine':
        # Outer product of Hanning windows.
        window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size))
    elif p.windowing == 'uniform':
        window = np.ones((p.score_size, p.score_size))
    else:
        raise ValueError("Unsupported windowing '{}' (expected 'cosine' or 'uniform')"
                         .format(p.windowing))
    window = np.tile(window.flatten(), p.anchor_num)

    # Mean colour, used to pad crops that leave the image.
    avg_chans = np.mean(im, axis=(0, 1))

    # Exemplar window: target plus context margin, as a square of equal area.
    wc_z = target_sz[0] + p.context_amount * sum(target_sz)
    hc_z = target_sz[1] + p.context_amount * sum(target_sz)
    s_z = round(np.sqrt(wc_z * hc_z))
    z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans)
    z = Variable(z_crop.unsqueeze(0))
    net.template(z.to(device))

    state['p'] = p
    state['net'] = net
    state['avg_chans'] = avg_chans
    state['window'] = window
    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    return state
def _change(r):
    """Return ``max(r, 1 / r)`` element-wise, so a ratio and its inverse are penalised alike."""
    return np.maximum(r, 1. / r)


def _sz(w, h):
    """Return the context-padded equivalent side ``sqrt((w + p) * (h + p))`` with ``p = (w + h) / 2``."""
    pad = (w + h) * 0.5
    sz2 = (w + pad) * (h + pad)
    return np.sqrt(sz2)


def _crop_back(image, bbox, out_sz, padding=-1):
    """Warp *image* onto an ``out_sz`` = (width, height) canvas with an axis-aligned affine map.

    x is scaled by ``(out_sz[0] - 1) / bbox[2]`` and shifted by ``-scale_x * bbox[0]``.
    y uses ``bbox[1]`` and ``bbox[3]`` in the same way. Uncovered pixels become *padding*.
    """
    a = (out_sz[0] - 1) / bbox[2]
    b = (out_sz[1] - 1) / bbox[3]
    c = -a * bbox[0]
    d = -b * bbox[1]
    # np.float (an alias of float64) was removed in NumPy 1.24.
    mapping = np.array([[a, 0, c],
                        [0, b, d]]).astype(np.float64)
    return cv2.warpAffine(image, mapping, (out_sz[0], out_sz[1]),
                          flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_CONSTANT,
                          borderValue=padding)


def siamese_track(state, im, mask_enable=False, refine_enable=False, device='cpu', debug=False):
    """Track the target into frame *im* and return the updated *state*.

    :param state: tracker state from ``siamese_init`` or a previous call
    :param im: current frame (H x W x C image)
    :param mask_enable: also predict a segmentation mask and a rotated box
    :param refine_enable: use ``net.track_refine`` for the mask (requires *mask_enable*)
    :param device: torch device for the search tensor
    :param debug: show the search area and wait for a key press
    :return: *state* with updated 'target_pos', 'target_sz', 'score', 'mask' and
        'ploygon'. The last two are the full-image mask and the 4x2 box corners when
        *mask_enable* is set, otherwise ``[]``.
    """
    p = state['p']
    net = state['net']
    avg_chans = state['avg_chans']
    window = state['window']
    target_pos = state['target_pos']
    target_sz = state['target_sz']

    # Context-padded target window, computed like the exemplar window in siamese_init.
    wc_x = target_sz[0] + p.context_amount * sum(target_sz)
    hc_x = target_sz[1] + p.context_amount * sum(target_sz)
    s_x = np.sqrt(wc_x * hc_x)
    # Image -> exemplar-input scale.
    scale_x = p.exemplar_size / s_x
    # Enlarge the window so that, at the same scale, it fills the instance (search) input.
    d_search = (p.instance_size - p.exemplar_size) / 2
    pad = d_search / scale_x
    s_x = s_x + 2 * pad
    crop_box = [target_pos[0] - round(s_x) / 2, target_pos[1] - round(s_x) / 2, round(s_x), round(s_x)]

    if debug:
        im_debug = im.copy()
        # int() truncates toward zero like the former np.int0 (removed in NumPy 2.0).
        crop_box_int = [int(v) for v in crop_box]
        cv2.rectangle(im_debug, (crop_box_int[0], crop_box_int[1]),
                      (crop_box_int[0] + crop_box_int[2], crop_box_int[1] + crop_box_int[3]), (255, 0, 0), 2)
        cv2.imshow('search area', im_debug)
        cv2.waitKey(0)

    # Search crop around the previous position.
    x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0))

    if mask_enable:
        score, delta, mask = net.track_mask(x_crop.to(device))
    else:
        score, delta = net.track(x_crop.to(device))

    # Box regression: 4 rows, one column per anchor.
    delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy()
    # Foreground probability per anchor: softmax over the two classification channels.
    score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1).permute(1, 0), dim=1).data[:,
            1].cpu().numpy()

    # Decode relative to the anchors: centre offsets scale with the anchor size,
    # width / height are log-space ratios.
    delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0]
    delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1]
    delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2]
    delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3]

    # Penalise changes of scale and aspect ratio w.r.t. the previous size (crop coordinates).
    target_sz_in_crop = target_sz * scale_x
    s_c = _change(_sz(delta[2, :], delta[3, :]) / _sz(target_sz_in_crop[0], target_sz_in_crop[1]))  # scale penalty
    r_c = _change((target_sz_in_crop[0] / target_sz_in_crop[1]) / (delta[2, :] / delta[3, :]))  # ratio penalty
    penalty = np.exp(-(r_c * s_c - 1) * p.penalty_k)
    pscore = penalty * score

    # Blend in the window prior (motion model) with weight p.window_influence.
    pscore = pscore * (1 - p.window_influence) + window * p.window_influence
    best_pscore_id = np.argmax(pscore)

    pred_in_crop = delta[:, best_pscore_id] / scale_x
    # Size update rate, damped by the penalised confidence of the best anchor.
    lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr  # lr for OTB

    res_x = pred_in_crop[0] + target_pos[0]
    res_y = pred_in_crop[1] + target_pos[1]
    res_w = target_sz[0] * (1 - lr) + pred_in_crop[2] * lr
    res_h = target_sz[1] * (1 - lr) + pred_in_crop[3] * lr
    target_pos = np.array([res_x, res_y])
    target_sz = np.array([res_w, res_h])

    if mask_enable:
        # (anchor, row, col) of the best anchor. The flat index runs over
        # p.anchor_num x score_size x score_size (the anchor count was hard-coded as 5).
        best_pscore_id_mask = np.unravel_index(best_pscore_id, (p.anchor_num, p.score_size, p.score_size))
        delta_x, delta_y = best_pscore_id_mask[2], best_pscore_id_mask[1]

        if refine_enable:
            mask = net.track_refine((delta_y, delta_x)).to(device).sigmoid().squeeze().view(
                p.out_size, p.out_size).cpu().data.numpy()
        else:
            mask = mask[0, :, delta_y, delta_x].sigmoid(). \
                squeeze().view(p.out_size, p.out_size).cpu().data.numpy()

        # Map the out_size x out_size mask back into full-image coordinates.
        s = crop_box[2] / p.instance_size
        sub_box = [crop_box[0] + (delta_x - p.base_size / 2) * p.total_stride * s,
                   crop_box[1] + (delta_y - p.base_size / 2) * p.total_stride * s,
                   s * p.exemplar_size, s * p.exemplar_size]
        s = p.out_size / sub_box[2]
        back_box = [-sub_box[0] * s, -sub_box[1] * s, state['im_w'] * s, state['im_h'] * s]
        mask_in_img = _crop_back(mask, back_box, (state['im_w'], state['im_h']))

        target_mask = (mask_in_img > p.seg_thr).astype(np.uint8)
        # findContours returns (contours, hierarchy) in OpenCV 2.4/4.x and
        # (image, contours, hierarchy) in 3.x; [-2] is the contour list in every version.
        contours = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        cnt_area = [cv2.contourArea(cnt) for cnt in contours]
        if len(contours) != 0 and np.max(cnt_area) > 100:
            # Rotated rectangle around the largest contour.
            contour = contours[np.argmax(cnt_area)]
            polygon = contour.reshape(-1, 2)
            rbox_in_img = cv2.boxPoints(cv2.minAreaRect(polygon))
        else:  # empty mask: fall back to the regressed axis-aligned box
            location = cxy_wh_2_rect(target_pos, target_sz)
            rbox_in_img = np.array([[location[0], location[1]],
                                    [location[0] + location[2], location[1]],
                                    [location[0] + location[2], location[1] + location[3]],
                                    [location[0], location[1] + location[3]]])

    # Keep the centre inside the image and the size within [10, image size].
    target_pos[0] = max(0, min(state['im_w'], target_pos[0]))
    target_pos[1] = max(0, min(state['im_h'], target_pos[1]))
    target_sz[0] = max(10, min(state['im_w'], target_sz[0]))
    target_sz[1] = max(10, min(state['im_h'], target_sz[1]))

    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    state['score'] = score[best_pscore_id]
    state['mask'] = mask_in_img if mask_enable else []
    state['ploygon'] = rbox_in_img if mask_enable else []
    return state
def track_vot(model, video, hp=None, mask_enable=False, refine_enable=False, device='cpu'):
    """Run the tracker over one VOT/OTB video, save the results and return failure count and speed.

    On VOT datasets a frame whose prediction does not overlap the ground truth is a
    failure, and the tracker is re-initialised on the ground truth five frames later.
    Reads the module-level globals ``args`` (dataset, debug, visualization, arch, resume),
    ``logger`` and ``v_id``.

    :param model: tracking network passed to ``siamese_init``
    :param video: dict with 'image_files', 'gt' (ndarray, one row per frame) and 'name'
    :param hp: tracker hyper-parameters forwarded to ``siamese_init``
    :param mask_enable: predict masks / rotated boxes instead of axis-aligned boxes
    :param refine_enable: use the mask refinement module
    :param device: torch device
    :return: (number of failures, frames per second)
    """
    # Per-frame record: 1 = init, 2 = lost, 0 = skipped, otherwise the predicted location.
    regions = []
    image_files, gt = video['image_files'], video['gt']
    start_frame, lost_times, toc = 0, 0, 0

    for f, image_file in enumerate(image_files):
        im = cv2.imread(image_file)
        tic = cv2.getTickCount()
        if f == start_frame:  # (re-)initialise on the ground truth
            cx, cy, w, h = get_axis_aligned_bbox(gt[f])
            target_pos = np.array([cx, cy])
            target_sz = np.array([w, h])
            state = siamese_init(im, target_pos, target_sz, model, hp, device)
            # (x, y, w, h) with (x, y) the top-left corner.
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
            regions.append(1 if 'VOT' in args.dataset else gt[f])
        elif f > start_frame:  # tracking
            state = siamese_track(state, im, mask_enable, refine_enable, device, args.debug)
            if mask_enable:
                # Rotated box: 4 corners flattened to 8 numbers.
                location = state['ploygon'].flatten()
                mask = state['mask']
            else:
                location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
                mask = []

            if 'VOT' in args.dataset:
                gt_polygon = ((gt[f][0], gt[f][1]), (gt[f][2], gt[f][3]),
                              (gt[f][4], gt[f][5]), (gt[f][6], gt[f][7]))
                if mask_enable:
                    pred_polygon = ((location[0], location[1]), (location[2], location[3]),
                                    (location[4], location[5]), (location[6], location[7]))
                else:
                    pred_polygon = ((location[0], location[1]),
                                    (location[0] + location[2], location[1]),
                                    (location[0] + location[2], location[1] + location[3]),
                                    (location[0], location[1] + location[3]))
                # NOTE(review): the import of vot_overlap is commented out at the top of this
                # file, so this raises NameError until that import is restored.
                b_overlap = vot_overlap(gt_polygon, pred_polygon, (im.shape[1], im.shape[0]))
            else:
                b_overlap = 1

            if b_overlap:
                regions.append(location)
            else:  # lost: record the failure and re-initialise 5 frames later
                regions.append(2)
                lost_times += 1
                start_frame = f + 5
        else:  # frames skipped after a failure
            regions.append(0)
        toc += cv2.getTickCount() - tic

        if args.visualization and f >= start_frame:  # visualization (skip lost frame)
            im_show = im.copy()
            if f == 0:
                cv2.destroyAllWindows()
            if gt.shape[0] > f:
                # Ground truth in green: polygon (8 values) or rectangle (x, y, w, h).
                if len(gt[f]) == 8:
                    # np.int was removed in NumPy 1.24; polylines expects int32 points.
                    gt_pts = np.array(gt[f], np.int32).reshape((-1, 1, 2))
                    cv2.polylines(im_show, [gt_pts], True, (0, 255, 0), 3)
                else:
                    # OpenCV needs integer pixel coordinates.
                    top_left = (int(gt[f, 0]), int(gt[f, 1]))
                    bottom_right = (int(gt[f, 0] + gt[f, 2]), int(gt[f, 1] + gt[f, 3]))
                    cv2.rectangle(im_show, top_left, bottom_right, (0, 255, 0), 3)
            # Prediction in yellow.
            if len(location) == 8:
                if mask_enable:
                    mask = mask > state['p'].seg_thr
                    # Paint the mask into the red channel (BGR image).
                    im_show[:, :, 2] = np.where(mask, 255, im_show[:, :, 2])
                location_int = np.asarray(location).astype(np.int32)
                cv2.polylines(im_show, [location_int.reshape((-1, 1, 2))], True, (0, 255, 255), 3)
            else:
                location = [int(l) for l in location]
                cv2.rectangle(im_show, (location[0], location[1]),
                              (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
            cv2.putText(im_show, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            cv2.putText(im_show, str(lost_times), (40, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.putText(im_show, str(state['score']) if 'score' in state else '', (40, 120), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv2.imshow(video['name'], im_show)
            cv2.waitKey(1)
    toc /= cv2.getTickFrequency()

    # Result directory name: architecture, mask / refine flags and checkpoint name.
    name = args.arch.split('.')[0] + '_' + ('mask_' if mask_enable else '') + ('refine_' if refine_enable else '') + \
        args.resume.split('/')[-1].split('.')[0]

    if 'VOT' in args.dataset:
        video_path = join('test', args.dataset, name,
                          'baseline', video['name'])
        if not isdir(video_path):
            makedirs(video_path)
        result_path = join(video_path, '{:s}_001.txt'.format(video['name']))
        # NOTE(review): writing VOT results to result_path is disabled because the
        # vot_float2str import is commented out; nothing is saved for VOT datasets.
    else:  # OTB
        video_path = join('test', args.dataset, name)
        if not isdir(video_path):
            makedirs(video_path)
        result_path = join(video_path, '{:s}.txt'.format(video['name']))
        with open(result_path, "w") as fin:
            for x in regions:
                fin.write(','.join([str(i) for i in x]) + '\n')

    logger.info('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps Lost: {:d}'.format(
        v_id, video['name'], toc, f / toc, lost_times))

    return lost_times, f / toc
def MultiBatchIouMeter(thrs, outputs, targets, start=None, end=None):
    """
    Compute the mean IoU of every object over a video, for each score threshold.

    Predicted per-object score maps are fused winner-takes-all: each pixel is assigned to
    the object with the highest score, and counts as foreground only if that score is
    above the threshold.

    :param thrs: iterable of score thresholds
    :param outputs: predicted score maps, indexed as outputs[object, frame, y, x]
    :param targets: ground-truth label maps, indexed as targets[frame, y, x]; the pixel
        value is the object id (0 = background)
    :param start: optional dict str(object_id) -> first frame of that object; when given,
        its keys define the object ids that are evaluated
    :param end: optional dict str(object_id) -> end frame of that object (needed with start)
    :return: float32 array of shape (num_objects, len(thrs)); entry [j, k] is the mean IoU of
        object j at thrs[k], or NaN when no frame was evaluated for that object
    """
    targets = np.array(targets)
    outputs = np.array(outputs)
    num_frame = targets.shape[0]
    if start is None:
        # Objects are numbered 1..N following the first axis of ``outputs``.
        object_ids = np.arange(outputs.shape[0]) + 1
    else:
        object_ids = [int(obj_id) for obj_id in start]
    num_object = len(object_ids)
    thrs = list(thrs)
    res = np.zeros((num_object, len(thrs)), dtype=np.float32)
    # 1-based id of the highest-scoring object at each pixel, and that highest score.
    output_max_id = np.argmax(outputs, axis=0).astype('uint8') + 1
    outputs_max = np.max(outputs, axis=0)

    # Ground-truth masks and evaluated frame ranges do not depend on the threshold:
    # build them once instead of once per threshold.
    target_masks = []
    frame_ranges = []
    for obj_id in object_ids:
        target_masks.append(targets == obj_id)
        if start is None:
            # The first and the last frame are excluded from the evaluation.
            frame_ranges.append(range(1, num_frame - 1))
        else:
            frame_ranges.append(range(start[str(obj_id)] + 1, end[str(obj_id)] - 1))

    for k, thr in enumerate(thrs):
        output_thr = outputs_max > thr
        for j in range(num_object):
            target_j = target_masks[j]
            iou = []
            for i in frame_ranges[j]:
                pred = (output_thr[i] * output_max_id[i]) == (j + 1)
                gt = target_j[i] > 0
                intxn = np.sum(pred & gt)
                union = np.sum(pred | gt)
                # An empty prediction on an empty ground truth is a perfect match.
                iou.append(intxn / union if union > 0 else 1)
            # Same NaN as np.mean([]) but without the "mean of empty slice" warning.
            res[j, k] = np.mean(iou) if iou else np.nan
    return res
def _fuse_object_masks(pred_masks, seg_thr):
    """
    Merge per-object soft masks into a single label map per frame.

    :param pred_masks: array indexed as [object, frame, y, x]
    :param seg_thr: pixels whose best score is not above this threshold become background
    :return: uint8 array indexed as [frame, y, x]; 0 = background, k = object k (1-based)
    """
    pred_masks = np.array(pred_masks)
    labels = np.argmax(pred_masks, axis=0).astype('uint8') + 1
    foreground = (np.max(pred_masks, axis=0) > seg_thr).astype('uint8')
    return labels * foreground


def track_vos(model, video, hp=None, mask_enable=False, refine_enable=False, mot_enable=False, device='cpu'):
    """
    Track and segment every annotated object of a video, then score the predicted masks.

    Relies on module-level names defined elsewhere in this file: ``args`` (save_mask,
    visualization, dataset), ``logger``, ``thrs`` and the current video index ``v_id``.

    :param model: trained SiamMask model
    :param video: dict with 'image_files', 'anno_files', 'name' and optionally
        'anno_init_files', 'start_frame', 'end_frame' (the last two keyed by str(object id))
    :param hp: tracker hyper-parameters forwarded to siamese_init
    :param mask_enable: forwarded to siamese_track; the mask is read from state['mask']
    :param refine_enable: forwarded to siamese_track
    :param mot_enable: keep individual object ids; when False all annotated instances are
        merged into a single foreground object
    :param device: device forwarded to siamese_init / siamese_track
    :return: (IoU array from MultiBatchIouMeter, or [] when the annotations do not cover
        every frame; speed as tracked frames per second, 0.0 if nothing was timed)
    """
    import os  # portable file-name handling for the saved masks

    image_files = video['image_files']
    # Segmentation ground truth is stored as label images.
    annos = [np.array(Image.open(x)) for x in video['anno_files']]
    if 'anno_init_files' in video:
        annos_init = [np.array(Image.open(x)) for x in video['anno_init_files']]
    else:
        annos_init = [annos[0]]
    if not mot_enable:
        # Single-object mode: merge every instance into one binary foreground mask.
        annos = [(anno > 0).astype(np.uint8) for anno in annos]
        annos_init = [(anno_init > 0).astype(np.uint8) for anno_init in annos_init]
    if 'start_frame' in video:
        object_ids = [int(obj_id) for obj_id in video['start_frame']]
    else:
        object_ids = [o_id for o_id in np.unique(annos[0]) if o_id != 0]
    # One (merged) initial annotation is then shared by every object.
    if len(object_ids) != len(annos_init):
        annos_init = annos_init * len(object_ids)
    object_num = len(object_ids)

    toc = 0
    tracked_frames = 0  # number of init/track steps, used for the speed estimate
    # -1 marks frames outside an object's [start, end] range.
    pred_masks = np.zeros((object_num, len(image_files), annos[0].shape[0], annos[0].shape[1])) - 1
    state = None
    for obj_id, o_id in enumerate(object_ids):
        if 'start_frame' in video:
            start_frame = video['start_frame'][str(o_id)]
            end_frame = video['end_frame'][str(o_id)]
        else:
            start_frame, end_frame = 0, len(image_files)
        for f, image_file in enumerate(image_files):
            # Frames outside the object's lifetime are neither read, tracked nor stored.
            if not start_frame <= f <= end_frame:
                continue
            im = cv2.imread(image_file)
            tic = cv2.getTickCount()
            if f == start_frame:
                # Initialise the tracker from the bounding box of the initial mask.
                mask = annos_init[obj_id] == o_id
                x, y, w, h = cv2.boundingRect(mask.astype(np.uint8))
                target_pos = np.array([x + w / 2, y + h / 2])
                target_sz = np.array([w, h])
                state = siamese_init(im, target_pos, target_sz, model, hp, device=device)
            else:
                state = siamese_track(state, im, mask_enable, refine_enable, device=device)
                mask = state['mask']
            toc += cv2.getTickCount() - tic
            tracked_frames += 1
            pred_masks[obj_id, f, :, :] = mask
    toc /= cv2.getTickFrequency()

    # Region similarity can only be computed when every frame has a ground-truth mask.
    if len(annos) == len(image_files):
        multi_mean_iou = MultiBatchIouMeter(thrs, pred_masks, annos,
                                            start=video['start_frame'] if 'start_frame' in video else None,
                                            end=video['end_frame'] if 'end_frame' in video else None)
        for i in range(object_num):
            for j, thr in enumerate(thrs):
                logger.info('Fusion Multi Object{:20s} IOU at {:.2f}: {:.4f}'.format(video['name'] + '_' + str(i + 1), thr,
                                                                                       multi_mean_iou[i, j]))
    else:
        multi_mean_iou = []

    # The fused label map needs the segmentation threshold of a tracker state.
    if state is not None and (args.save_mask or args.visualization):
        pred_mask_final = _fuse_object_masks(pred_masks, state['p'].seg_thr)
        if args.save_mask:
            video_path = join('test', args.dataset, 'SiamMask', video['name'])
            if not isdir(video_path): makedirs(video_path)
            for i in range(pred_mask_final.shape[0]):
                # basename/splitext also handle the OS separator ('\\' on Windows).
                frame_name = os.path.splitext(os.path.basename(image_files[i]))[0]
                cv2.imwrite(join(video_path, frame_name + '.png'), pred_mask_final[i].astype(np.uint8))
        if args.visualization:
            # Displayed only after tracking has finished, so playback is not real time.
            colors = np.random.randint(128, 255, size=(object_num, 3), dtype="uint8")
            colors = np.vstack([[0, 0, 0], colors]).astype("uint8")
            colored = colors[pred_mask_final]
            for f, image_file in enumerate(image_files):
                output = ((0.4 * cv2.imread(image_file)) + (0.6 * colored[f, :, :, :])).astype("uint8")
                cv2.imshow("mask", output)
                cv2.waitKey(1)

    fps = tracked_frames / toc if toc > 0 else 0.0
    logger.info('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps'.format(
        v_id, video['name'], toc, fps))
    return multi_mean_iou, fps
def main():
    """
    Evaluation entry point: build the model, run it on every selected video and log results.

    VOS datasets (DAVIS2016, DAVIS2017, ytb_vos) together with --mask are scored with the
    segmentation IoU (track_vos); everything else runs track_vot and reports lost targets.
    """
    # track_vos / track_vot read args, logger and v_id as module-level globals.
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    logger = logging.getLogger('global')
    logger.info(args)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    # Inference mode: disables dropout and makes BatchNorm use its running statistics.
    model.eval()
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)

    dataset = load_dataset(args.dataset)
    # Only these datasets provide mask annotations (VOS); the rest are evaluated VOT-style.
    vos_enable = bool(args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask)
    hp = cfg['hp'] if 'hp' in cfg.keys() else None

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []
    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue
        if vos_enable:
            # Multi-object tracking is enabled for the multi-instance datasets.
            iou_list, speed = track_vos(model, dataset[video], hp, args.mask, args.refine,
                                        args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], hp, args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

    if vos_enable:
        # track_vos returns [] for videos without per-frame annotations; such entries
        # cannot be concatenated with the (objects, thresholds) IoU arrays.
        scored = [iou for iou in iou_lists if len(iou) > 0]
        if scored:
            for thr, iou in zip(thrs, np.mean(np.concatenate(scored), axis=0)):
                logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
        else:
            logger.info('No fully annotated video evaluated; mIoU not computed')
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))
    if speed_list:
        logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))


if __name__ == '__main__':
    main()
train_siammask.py
# --------------------------------------------------------
# 基础网络的训练
# --------------------------------------------------------
import argparse
import logging
import os
import cv2
import shutil
import time
import json
import math
import torch
from torch.utils.data import DataLoader
from utils.log_helper import init_log, print_speed, add_file_handler, Dummy
from utils.load_helper import load_pretrain, restore_from
from utils.average_meter_helper import AverageMeter
from datasets.siam_mask_dataset import DataSets
from utils.lr_helper import build_lr_scheduler
from tensorboardX import SummaryWriter
from utils.config_helper import load_config
from torch.utils.collect_env import get_pretty_env_info
# Let cuDNN benchmark and cache the fastest convolution algorithms for the shapes it sees.
torch.backends.cudnn.benchmark = True

# Command-line options of the base SiamMask training script.
parser = argparse.ArgumentParser(description='PyTorch Tracking SiamMask Training')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                    help='number of data loading workers (default: 0)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch', default=64, type=int,
                    metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--clip', default=10.0, type=float,
                    help='gradient clip value')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', default='',
                    help='use pre-trained model')
parser.add_argument('--config', dest='config', required=True,
                    help='hyperparameter of SiamMask in json format')
parser.add_argument('--arch', dest='arch', default='', choices=['Custom',],
                    help='architecture of pretrained model')
parser.add_argument('-l', '--log', default="log.txt", type=str,
                    help='log file')
parser.add_argument('-s', '--save_dir', default='snapshot', type=str,
                    help='save dir')
parser.add_argument('--log-dir', default='board', help='TensorBoard log dir')
# Best accuracy seen so far; written into every checkpoint.
best_acc = 0.
def collect_env_info():
    """
    Describe the runtime environment for the training log.

    :return: PyTorch's environment report followed by a line with the OpenCV version
    """
    report = get_pretty_env_info()
    return "{}\n OpenCV ({})".format(report, cv2.__version__)
def build_data_loader(cfg):
    """
    Build the training and validation data loaders.

    When cfg has no 'val_datasets' entry, the training datasets are reused for
    validation, and that default is stored back into cfg. Batch size, number of
    loader worker processes and the epoch count come from the global ``args``.

    :param cfg: config dict with 'train_datasets', 'anchors' and optionally 'val_datasets'
    :return: (train_loader, val_loader)
    """
    logger = logging.getLogger('global')

    logger.info("build train dataset")  # train_dataset
    train_set = DataSets(cfg['train_datasets'], cfg['anchors'], args.epochs)
    train_set.shuffle()

    logger.info("build val dataset")  # val_dataset
    val_set = DataSets(cfg.setdefault('val_datasets', cfg['train_datasets']), cfg['anchors'])
    val_set.shuffle()

    loader_options = dict(batch_size=args.batch, num_workers=args.workers,
                          pin_memory=True, sampler=None)
    train_loader = DataLoader(train_set, **loader_options)
    val_loader = DataLoader(val_set, **loader_options)
    logger.info('build dataset done')
    return train_loader, val_loader
def build_opt_lr(model, cfg, args, epoch):
    """
    Create the SGD optimizer and learning-rate scheduler for the current training stage.

    While the backbone reports no trainable parameter groups, only the RPN groups are
    optimized (requested with the extra 'mask' argument). Otherwise the backbone, RPN and
    mask-head groups are optimized together.

    :param model: model exposing features / rpn_model / mask_model with param_groups()
    :param cfg: config dict; cfg['lr'] holds start_lr and the per-module lr multipliers
    :param args: parsed arguments (lr, momentum, weight_decay, epochs)
    :param epoch: epoch the scheduler is stepped to
    :return: (optimizer, lr_scheduler)
    """
    lr_cfg = cfg['lr']
    start_lr = lr_cfg['start_lr']
    backbone_groups = model.features.param_groups(start_lr, lr_cfg['feature_lr_mult'])
    if len(backbone_groups) > 0:
        trainable_params = (backbone_groups
                            + model.rpn_model.param_groups(start_lr, lr_cfg['rpn_lr_mult'])
                            + model.mask_model.param_groups(start_lr, lr_cfg['mask_lr_mult']))
    else:
        trainable_params = model.rpn_model.param_groups(start_lr, lr_cfg['rpn_lr_mult'], 'mask')

    optimizer = torch.optim.SGD(trainable_params, args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = build_lr_scheduler(optimizer, lr_cfg, epochs=args.epochs)
    lr_scheduler.step(epoch)
    return optimizer, lr_scheduler
def main():
    """
    Train the base SiamMask model (backbone, RPN and mask head).

    Sets the module-level globals ``args``, ``best_acc``, ``tb_writer`` and ``logger``
    that the other functions of this script read.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    # Fail fast with a usage message and non-zero exit status instead of a silent exit().
    if args.arch != 'Custom':
        parser.error('invalid architecture: {}'.format(args.arch))

    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    train_loader, _ = build_data_loader(cfg)

    from custom import Custom
    model = Custom(pretrain=True, anchors=cfg['anchors'])
    logger.info(model)
    if args.pretrained:
        model = load_pretrain(model, args.pretrained)
    # For GPU training: model = model.cuda() and wrap with
    # torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    dist_model = torch.nn.DataParallel(model)

    if args.resume and args.start_epoch != 0:
        # Presumably restores the backbone (un)freezing schedule of the resume epoch.
        model.features.unfix((args.start_epoch - 1) / args.epochs)
    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)

    if args.resume:
        if not os.path.isfile(args.resume):
            parser.error('{} is not a valid file'.format(args.resume))
        model, optimizer, args.start_epoch, best_acc, _ = restore_from(model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)
    logger.info('model prepare done')
    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    """
    Run the training loop over every batch produced by ``train_loader``.

    The training dataset is built with args.epochs and
    num_per_epoch = len(dataset) // epochs // batch, so epoch boundaries are derived from
    the batch index. On each new epoch the loop:
    - saves a checkpoint to args.save_dir;
    - stops once args.epochs is reached;
    - possibly unfixes part of the backbone, rebuilding optimizer and scheduler;
    - steps the scheduler.

    Reads the globals ``args`` and ``tb_writer``; sets ``tb_index``, ``cur_lr`` and ``logger``.

    :param train_loader: DataLoader over the training DataSets
    :param model: torch.nn.DataParallel-wrapped model (accessed through ``.module``)
    :param optimizer: optimizer from build_opt_lr
    :param lr_scheduler: scheduler from build_opt_lr
    :param epoch: epoch to start from
    :param cfg: config dict (uses 'loss', 'clip', 'anchors'; also passed to the model)
    """
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    # For GPU training: model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        # Skip updates for NaN / inf or exploding losses.
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    # At least one batch per epoch; 0 would make the epoch arithmetic divide by zero.
    num_per_epoch = max(1, len(train_loader.dataset) // args.epochs // args.batch)
    logger.info("num_per_epoch {}".format(num_per_epoch))
    start_epoch = epoch

    for batch_idx, batch in enumerate(train_loader):
        if epoch != batch_idx // num_per_epoch + start_epoch:  # next epoch
            epoch = batch_idx // num_per_epoch + start_epoch
            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'anchor_cfg': cfg['anchors']
            }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))
            if epoch == args.epochs:
                return
            # New trainable layers require a new optimizer / scheduler.
            if model.module.features.unfix(epoch / args.epochs):
                logger.info('unfix part model.')
                optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args, epoch)
            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch:{}'.format(epoch))

        tb_index = batch_idx
        if batch_idx % num_per_epoch == 0 and batch_idx != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx + 1), pg['lr'], tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        # Batch tensors are passed as-is: torch.autograd.Variable(t) is a deprecated no-op
        # wrapper. For GPU training, move each tensor with .cuda().
        x = {
            'cfg': cfg,
            'template': batch[0],
            'search': batch[1],
            'label_cls': batch[2],
            'label_loc': batch[3],
            'label_loc_weight': batch[4],
            'label_mask': batch[6],
            'label_mask_weight': batch[7],
        }
        outputs = model(x)
        rpn_cls_loss = torch.mean(outputs['losses'][0])
        rpn_loc_loss = torch.mean(outputs['losses'][1])
        rpn_mask_loss = torch.mean(outputs['losses'][2])
        mask_iou_mean = torch.mean(outputs['accuracy'][0])
        mask_iou_at_5 = torch.mean(outputs['accuracy'][1])
        mask_iou_at_7 = torch.mean(outputs['accuracy'][2])
        # Weighted sum of the classification, box-regression and mask losses.
        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']
        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight

        optimizer.zero_grad()
        loss.backward()
        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(), cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(), cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(model.module.mask_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # gradient clip
        siammask_loss = loss.item()
        if is_valid_number(siammask_loss):
            optimizer.step()

        batch_time = time.time() - end
        # Log plain floats so the meters and TensorBoard hold no reference to the autograd graph.
        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = rpn_cls_loss.item(), rpn_loc_loss.item(), rpn_mask_loss.item()
        mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = mask_iou_mean.item(), mask_iou_at_5.item(), mask_iou_at_7.item()
        avg.update(batch_time=batch_time, rpn_cls_loss=rpn_cls_loss, rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss, siammask_loss=siammask_loss,
                   mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)
        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss, tb_index)
        tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        tb_writer.add_scalar('mask/AP@.5', mask_iou_at_5, tb_index)
        tb_writer.add_scalar('mask/AP@.7', mask_iou_at_7, tb_index)
        end = time.time()

        if (batch_idx + 1) % args.print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                        '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                        '\t{mask_iou_mean:s}\t{mask_iou_at_5:s}\t{mask_iou_at_7:s}'.format(
                            epoch + 1, (batch_idx + 1) % num_per_epoch, num_per_epoch, lr=cur_lr,
                            batch_time=avg.batch_time, data_time=avg.data_time,
                            rpn_cls_loss=avg.rpn_cls_loss, rpn_loc_loss=avg.rpn_loc_loss,
                            rpn_mask_loss=avg.rpn_mask_loss, siammask_loss=avg.siammask_loss,
                            mask_iou_mean=avg.mask_iou_mean, mask_iou_at_5=avg.mask_iou_at_5,
                            mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(batch_idx + 1, avg.batch_time.avg, args.epochs * num_per_epoch)
def save_checkpoint(state, is_best, filename='checkpoint.pth', best_file='model_best.pth'):
    """
    Serialize ``state`` to ``filename`` with torch.save.

    When ``is_best`` is true the written file is additionally copied to ``best_file``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, best_file)


if __name__ == '__main__':
    main()
train_siammask_refine.py
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# 该文件中训练掩膜细化网络,与基础网络的训练类似,加载的模型是包含refine的模型
# --------------------------------------------------------
import argparse
import logging
import os
import cv2
import shutil
import time
import json
import math
import torch
from torch.utils.data import DataLoader
from utils.log_helper import init_log, print_speed, add_file_handler, Dummy
from utils.load_helper import load_pretrain, restore_from
from utils.average_meter_helper import AverageMeter
from datasets.siam_mask_dataset import DataSets
from utils.lr_helper import build_lr_scheduler
from tensorboardX import SummaryWriter
from utils.config_helper import load_config
from torch.utils.collect_env import get_pretty_env_info
# cuDNN autotuner: benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# Command-line options of the refinement training script, declared as a table.
parser = argparse.ArgumentParser(description='PyTorch Tracking Training')
for _flags, _options in (
    (('-j', '--workers'), dict(default=16, type=int, metavar='N',
                               help='number of data loading workers (default: 16)')),
    (('--epochs',), dict(default=50, type=int, metavar='N',
                         help='number of total epochs to run')),
    (('--start-epoch',), dict(default=0, type=int, metavar='N',
                              help='manual epoch number (useful on restarts)')),
    (('-b', '--batch'), dict(default=64, type=int, metavar='N',
                             help='mini-batch size (default: 64)')),
    (('--lr', '--learning-rate'), dict(default=0.001, type=float, metavar='LR',
                                       help='initial learning rate')),
    (('--momentum',), dict(default=0.9, type=float, metavar='M', help='momentum')),
    (('--weight-decay', '--wd'), dict(default=1e-4, type=float, metavar='W',
                                      help='weight decay (default: 1e-4)')),
    (('--clip',), dict(default=10.0, type=float, help='gradient clip value')),
    (('--print-freq', '-p'), dict(default=10, type=int, metavar='N',
                                  help='print frequency (default: 10)')),
    (('--resume',), dict(default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')),
    (('--pretrained',), dict(dest='pretrained', default='', help='use pre-trained model')),
    (('--config',), dict(dest='config', required=True,
                         help='hyperparameter of SiamRPN in json format')),
    (('--arch',), dict(dest='arch', default='', choices=['Custom', ''],
                       help='architecture of pretrained model')),
    (('-l', '--log'), dict(default="log.txt", type=str, help='log file')),
    (('-s', '--save_dir'), dict(default='snapshot', type=str, help='save dir')),
    (('--log-dir',), dict(default='board', help='TensorBoard log dir')),
):
    parser.add_argument(*_flags, **_options)
del _flags, _options

# Best accuracy seen so far; written into every checkpoint.
best_acc = 0.
def collect_env_info():
    """Return PyTorch's environment report with the OpenCV version appended on a new line."""
    torch_report = get_pretty_env_info()
    opencv_version = cv2.__version__
    return torch_report + "\n OpenCV ({})".format(opencv_version)
def build_data_loader(cfg):
    """
    Create the (train_loader, val_loader) pair for refinement training.

    Validation falls back to the training datasets when cfg lacks 'val_datasets'; the
    fallback is written back into cfg. Batch size, worker processes and epoch count come
    from the global ``args``.
    """
    logger = logging.getLogger('global')

    logger.info("build train dataset")  # train_dataset
    train_set = DataSets(cfg['train_datasets'], cfg['anchors'], args.epochs)
    train_set.shuffle()

    logger.info("build val dataset")  # val_dataset
    if 'val_datasets' not in cfg:
        cfg['val_datasets'] = cfg['train_datasets']
    val_set = DataSets(cfg['val_datasets'], cfg['anchors'])
    val_set.shuffle()

    train_loader, val_loader = [
        DataLoader(dataset, batch_size=args.batch, num_workers=args.workers,
                   pin_memory=True, sampler=None)
        for dataset in (train_set, val_set)
    ]
    logger.info('build dataset done')
    return train_loader, val_loader
def build_opt_lr(model, cfg, args, epoch):
    """
    Create the optimizer and scheduler for refinement training.

    Only the mask head and the refine module are optimized; both use
    cfg['lr']['mask_lr_mult'] as their learning-rate multiplier.

    :param model: model exposing mask_model / refine_model with param_groups()
    :param cfg: config dict with the 'lr' section
    :param args: parsed arguments (lr, momentum, weight_decay, epochs)
    :param epoch: epoch the scheduler is stepped to
    :return: (optimizer, lr_scheduler)
    """
    lr_cfg = cfg['lr']
    trainable_params = []
    for module in (model.mask_model, model.refine_model):
        trainable_params += module.param_groups(lr_cfg['start_lr'], lr_cfg['mask_lr_mult'])

    # Stochastic gradient descent with momentum and weight decay.
    optimizer = torch.optim.SGD(trainable_params, args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = build_lr_scheduler(optimizer, lr_cfg, epochs=args.epochs)
    lr_scheduler.step(epoch)
    return optimizer, lr_scheduler
def main():
    """
    Train the SiamMask mask-refinement module.

    Mirrors the base training script, but the model is built without the pretrain=True
    flag, and build_opt_lr only optimizes the mask head and the refine module. Sets the
    globals ``args``, ``best_acc``, ``tb_writer`` and ``logger``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    # Fail fast with a usage message and non-zero exit status instead of a silent exit().
    if args.arch != 'Custom':
        parser.error('invalid architecture: {}'.format(args.arch))

    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset
    train_loader, _ = build_data_loader(cfg)

    from custom import Custom
    model = Custom(anchors=cfg['anchors'])
    logger.info(model)
    if args.pretrained:
        model = load_pretrain(model, args.pretrained)
    # For GPU training: model = model.cuda() and wrap with
    # torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    dist_model = torch.nn.DataParallel(model)

    if args.resume and args.start_epoch != 0:
        # Presumably restores the backbone (un)freezing schedule of the resume epoch.
        model.features.unfix((args.start_epoch - 1) / args.epochs)
    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            parser.error('{} is not a valid file'.format(args.resume))
        model, optimizer, args.start_epoch, best_acc, _ = restore_from(model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)
    logger.info('model prepare done')
    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
def BNtoFixed(m):
    """
    Switch ``m`` to evaluation mode if it is a batch-normalization layer.

    Intended for ``module.apply(BNtoFixed)``. Layers are matched by class name, so any
    class whose name contains 'BatchNorm' (BatchNorm1d/2d/3d, SyncBatchNorm, ...) is put
    into eval mode and uses its running statistics. Every other module is left untouched.
    """
    if 'BatchNorm' in type(m).__name__:
        m.eval()
def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    """
    Training loop of the refinement stage.

    Backbone and RPN are kept in eval mode (BatchNorm layers included) and are not part of
    the optimizer built by build_opt_lr. Only the mask head and the refine module train.
    Epoch boundaries are derived from the batch index, since the dataset is built with
    args.epochs. On each new epoch the loop:
    - saves a checkpoint;
    - stops at args.epochs;
    - rebuilds optimizer and scheduler.

    Reads the globals ``args`` and ``tb_writer``; sets ``tb_index``, ``cur_lr`` and ``logger``.

    :param train_loader: DataLoader over the training DataSets
    :param model: torch.nn.DataParallel-wrapped model (accessed through ``.module``)
    :param optimizer: optimizer from build_opt_lr
    :param lr_scheduler: scheduler from build_opt_lr
    :param epoch: epoch to start from
    :param cfg: config dict (uses 'loss', 'clip', 'anchors'; also passed to the model)
    """
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    # Fixed parts: backbone and RPN run in eval mode, including their BatchNorm layers.
    model.module.features.eval()
    model.module.rpn_model.eval()
    model.module.features.apply(BNtoFixed)
    model.module.rpn_model.apply(BNtoFixed)
    # Trained parts.
    model.module.mask_model.train()
    model.module.refine_model.train()
    # For GPU training: model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        # Skip updates for NaN / inf or exploding losses.
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    # At least one batch per epoch; 0 would make the epoch arithmetic divide by zero.
    num_per_epoch = max(1, len(train_loader.dataset) // args.epochs // args.batch)
    start_epoch = epoch

    for batch_idx, batch in enumerate(train_loader):
        if epoch != batch_idx // num_per_epoch + start_epoch:  # next epoch
            epoch = batch_idx // num_per_epoch + start_epoch
            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'anchor_cfg': cfg['anchors']
            }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))
            if epoch == args.epochs:
                return
            optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args, epoch)
            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch:{}'.format(epoch))

        tb_index = batch_idx
        if batch_idx % num_per_epoch == 0 and batch_idx != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx + 1), pg['lr'], tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        # Batch tensors are passed as-is: torch.autograd.Variable(t) is a deprecated no-op
        # wrapper. For GPU training, move each tensor with .cuda().
        x = {
            'cfg': cfg,
            'template': batch[0],
            'search': batch[1],
            'label_cls': batch[2],
            'label_loc': batch[3],
            'label_loc_weight': batch[4],
            'label_mask': batch[6],
            'label_mask_weight': batch[7],
        }
        outputs = model(x)
        rpn_cls_loss = torch.mean(outputs['losses'][0])
        rpn_loc_loss = torch.mean(outputs['losses'][1])
        rpn_mask_loss = torch.mean(outputs['losses'][2])
        mask_iou_mean = torch.mean(outputs['accuracy'][0])
        mask_iou_at_5 = torch.mean(outputs['accuracy'][1])
        mask_iou_at_7 = torch.mean(outputs['accuracy'][2])
        # Weighted sum of the classification, box-regression and mask losses.
        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']
        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight

        optimizer.zero_grad()
        loss.backward()
        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(), cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(), cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(model.module.mask_model.parameters(), cfg['clip']['mask'])
            torch.nn.utils.clip_grad_norm_(model.module.refine_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # gradient clip
        siammask_loss = loss.item()
        if is_valid_number(siammask_loss):
            optimizer.step()

        batch_time = time.time() - end
        # Log plain floats so the meters and TensorBoard hold no reference to the autograd graph.
        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = rpn_cls_loss.item(), rpn_loc_loss.item(), rpn_mask_loss.item()
        mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = mask_iou_mean.item(), mask_iou_at_5.item(), mask_iou_at_7.item()
        avg.update(batch_time=batch_time, rpn_cls_loss=rpn_cls_loss, rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss, siammask_loss=siammask_loss,
                   mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)
        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss, tb_index)
        tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        tb_writer.add_scalar('mask/AP@.5', mask_iou_at_5, tb_index)
        tb_writer.add_scalar('mask/AP@.7', mask_iou_at_7, tb_index)
        end = time.time()

        if (batch_idx + 1) % args.print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                        '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                        '\t{mask_iou_mean:s}\t{mask_iou_at_5:s}\t{mask_iou_at_7:s}'.format(
                            epoch + 1, (batch_idx + 1) % num_per_epoch, num_per_epoch, lr=cur_lr,
                            batch_time=avg.batch_time, data_time=avg.data_time,
                            rpn_cls_loss=avg.rpn_cls_loss, rpn_loc_loss=avg.rpn_loc_loss,
                            rpn_mask_loss=avg.rpn_mask_loss, siammask_loss=avg.siammask_loss,
                            mask_iou_mean=avg.mask_iou_mean, mask_iou_at_5=avg.mask_iou_at_5,
                            mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(batch_idx + 1, avg.batch_time.avg, args.epochs * num_per_epoch)
def save_checkpoint(state, is_best, filename='checkpoint.pth', best_file='model_best.pth'):
    """
    Write ``state`` to ``filename`` (torch.save), duplicating it as ``best_file`` when
    ``is_best`` is set.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(src=filename, dst=best_file)


if __name__ == '__main__':
    main()
videoDemo.py
from __future__ import division
import flask
import argparse
import numpy as np
import cv2
from os.path import join, isdir, isfile
from utils.load_helper import load_pretrain
import torch
from utils.config_helper import load_config
from tools.test import *
# Command-line options of the video demo; they are parsed once, at import time.
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
for _flags, _options in (
    # Checkpoint file with the trained SiamMask weights.
    (('--resume',), dict(default='SiamMask.pth', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')),
    # JSON file with the network / tracker hyper-parameters.
    (('--config',), dict(dest='config', default='config.json',
                         help='hyper-parameter of SiamMask in json format')),
    # Image sequence to process.
    (('--base_path',), dict(default='../../data/car', help='datasets')),
    # Force CPU execution.
    (('--cpu',), dict(action='store_true', help='cpu mode')),
):
    parser.add_argument(*_flags, **_options)
del _flags, _options
args = parser.parse_args()
def process_vedio(vedio_path, initRect, output_path="../data/middle.mp4"):
    """
    Track one target through a video with SiamMask and write an annotated copy.

    Each tracked frame gets the predicted mask painted into its red channel and the
    predicted polygon drawn in green. The frame is shown in a window and appended to
    ``output_path``. Pressing any key stops early.

    :param vedio_path: input video path (anything cv2.VideoCapture can open)
    :param initRect: (x, y, w, h) of the target in the first frame, in pixels
    :param output_path: annotated output video (MP4V codec, 10 fps, size of the input frames)
    :return: None
    :raises IOError: if the first frame of the video cannot be read
    """
    # Honour --cpu, consistently with the evaluation script.
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    torch.backends.cudnn.benchmark = True

    cfg = load_config(args)
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)
    # eval(): dropout off and BatchNorm on running statistics; then move to the device.
    siammask.eval().to(device)

    x, y, w, h = initRect
    cap = cv2.VideoCapture(vedio_path)
    writer = None
    try:
        ret, frame = cap.read()
        if not ret:
            raise IOError('Cannot read the first frame of {}'.format(vedio_path))
        target_pos = np.array([x + w / 2, y + h / 2])
        target_sz = np.array([w, h])
        # Inputs must live on the same device as the model.
        state = siamese_init(frame, target_pos, target_sz, siammask, cfg['hp'], device=device)

        # VideoWriter drops frames whose size differs from frameSize, so use the real one.
        height, width = frame.shape[:2]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc('M', 'P', '4', 'V'), 10, (width, height))
        while True:
            ret, im = cap.read()
            if not ret:
                break
            state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device)
            location = state['ploygon'].flatten()
            mask = state['mask'] > state['p'].seg_thr
            im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
            # np.intp replaces the np.int0 alias removed in NumPy 2.0.
            cv2.polylines(im, [np.intp(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
            writer.write(im)
            cv2.imshow('SiamMask', im)
            if cv2.waitKey(1) > 0:
                break
    finally:
        if writer is not None:
            writer.release()
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    process_vedio('../data/car.mp4', [162, 121, 28, 25])
SiamMask_master\utils
anchors.py
# --------------------------------------------------------
# anchor处理帮助类
# --------------------------------------------------------
import numpy as np
import math
from utils.bbox_helper import center2corner, corner2center
class Anchors:
    """
    Anchor generator.

    Defaults: stride 8, aspect ratios [0.33, 0.5, 1, 2, 3], scale 8 and one anchor
    centre per stride cell. Any key of ``cfg`` overrides the attribute of the same name.
    """

    def __init__(self, cfg):
        # Defaults first; cfg may replace any of them.
        self.stride = 8
        self.ratios = [0.33, 0.5, 1, 2, 3]   # anchor height / width ratios
        self.scales = [8]
        self.round_dight = 0      # > 0: round base sides to this many digits; 0: truncate to int
        self.image_center = 0     # arguments of the last generate_all_anchors() call (cache key)
        self.size = 0
        self.anchor_density = 1   # anchor centres per stride, along each axis
        self.__dict__.update(cfg)

        per_cell = self.anchor_density ** 2
        self.anchor_num = len(self.scales) * len(self.ratios) * per_cell
        self.anchors = None       # (anchor_num, 4) corner boxes around the origin
        self.all_anchors = None   # (corners, centres), each stacked over a size x size grid
        self.generate_anchors()

    def generate_anchors(self):
        """Fill ``self.anchors`` with one [x1, y1, x2, y2] row per (offset, ratio, scale)."""
        boxes = np.zeros((self.anchor_num, 4), dtype=np.float32)
        self.anchors = boxes
        base_area = self.stride * self.stride
        spacing = self.stride / self.anchor_density
        shifts = np.arange(self.anchor_density) * spacing
        shifts = shifts - np.mean(shifts)   # centre the sub-cell offsets on 0
        grid_x, grid_y = np.meshgrid(shifts, shifts)

        row = 0
        for dx, dy in zip(grid_x.flatten(), grid_y.flatten()):
            for ratio in self.ratios:
                # Base width/height keep the area close to stride^2 for this ratio.
                if self.round_dight > 0:
                    base_w = round(math.sqrt(base_area * 1. / ratio), self.round_dight)
                    base_h = round(base_w * ratio, self.round_dight)
                else:
                    base_w = int(math.sqrt(base_area * 1. / ratio))
                    base_h = int(base_w * ratio)
                for scale in self.scales:
                    half_w = base_w * scale * 0.5
                    half_h = base_h * scale * 0.5
                    boxes[row] = [-half_w + dx, -half_h + dy, half_w + dx, half_h + dy]
                    row += 1

    def generate_all_anchors(self, im_c, size):
        """
        Tile the base anchors over a size x size grid spaced by ``stride``.

        :param im_c: centre coordinate of the image (used for both x and y)
        :param size: grid side length
        :return: False if called with the same arguments as last time (nothing is
                 recomputed), otherwise True after refreshing ``self.all_anchors``
        """
        if self.image_center == im_c and self.size == size:
            return False
        self.image_center = im_c
        self.size = size

        # Top-left grid point; x and y are symmetric.
        origin = im_c - size // 2 * self.stride
        centred = self.anchors + np.array([origin] * 4, dtype=np.float32)
        corners = [centred[:, k].reshape(self.anchor_num, 1, 1) for k in range(4)]
        cx, cy, w, h = corner2center(corners)

        # Shift centres along x (last axis) and y (middle axis).
        offsets = np.arange(0, size)
        cx = cx + offsets.reshape(1, 1, -1) * self.stride
        cy = cy + offsets.reshape(1, -1, 1) * self.stride

        # Broadcast every component to (anchor_num, size, size).
        canvas = np.zeros((self.anchor_num, size, size), dtype=np.float32)
        cx, cy, w, h = [part + canvas for part in (cx, cy, w, h)]
        x1, y1, x2, y2 = center2corner([cx, cy, w, h])
        self.all_anchors = (np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h]))
        return True
if __name__ == '__main__':
    # Visual check: 2x2 anchor centres per cell, stride 16, on a 17 x 17 grid.
    demo = Anchors(cfg={'stride': 16, 'anchor_density': 2})
    grid_size = (255 - 127) // 16 + 1 + 8
    demo.generate_all_anchors(im_c=255 // 2, size=grid_size)
    print(demo.all_anchors)
average_meter_helper.py
# --------------------------------------------------------
# 计算和存储指标数据
# --------------------------------------------------------
import numpy as np
class Meter(object):
    """Snapshot of one metric: its name, latest value and running average."""

    def __init__(self, name, val, avg):
        self.name = name
        self.val = val
        self.avg = avg

    def __repr__(self):
        return "{name}: {val:.6f} ({avg:.6f})".format(
            name=self.name, val=self.val, avg=self.avg
        )

    def __format__(self, *tuples, **kwargs):
        # Any format spec is ignored; formatting always yields the repr.
        return self.__repr__()


class AverageMeter(object):
    """Computes and stores the average and current value of named metrics.

    Metrics are created on first ``update(name=value)`` and read back as attributes,
    e.g. ``meter.name.val`` / ``meter.name.avg``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all metrics."""
        self.val = {}    # latest per-sample value of each metric
        self.sum = {}    # running total of the raw (batch-summed) values
        self.count = {}  # number of samples seen per metric

    def update(self, batch=1, **kwargs):
        """Record one batch; each keyword value is the total over ``batch`` samples."""
        self.val.update({k: v / float(batch) for k, v in kwargs.items()})
        for k, v in kwargs.items():
            if k not in self.sum:
                self.sum[k] = 0
                self.count[k] = 0
            self.sum[k] += v
            self.count[k] += batch

    def __repr__(self):
        return ''.join(self.format_str(k) for k in self.sum)

    def format_str(self, attr):
        """Format one metric as 'name: current (average) '."""
        return "{name}: {val:.6f} ({avg:.6f}) ".format(
            name=attr,
            val=float(self.val[attr]),
            avg=float(self.sum[attr]) / self.count[attr])

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails, i.e. for metric names.
        # Dunder probes from copy/pickle (__deepcopy__, __setstate__, ...) must get an
        # AttributeError instead of a Meter.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        # On a half-constructed instance (e.g. while unpickling), 'sum' is not set yet;
        # reading self.sum here would re-enter __getattr__ and recurse forever.
        sums = self.__dict__.get('sum')
        if sums is None:
            raise AttributeError(attr)
        if attr not in sums:
            print("invalid key '{}'".format(attr))
            return Meter(attr, 0, 0)
        return Meter(attr, self.val[attr], self.avg(attr))

    def avg(self, attr):
        """Running average of ``attr`` over all samples seen."""
        return float(self.sum[attr]) / self.count[attr]
class IouMeter(object):
    """Collects per-sample mask IoU at several binarisation thresholds."""

    def __init__(self, thrs, sz):
        """
        :param thrs: thresholds applied to the predicted mask
        :param sz: maximum number of samples stored
        """
        self.sz = sz
        self.iou = np.zeros((sz, len(thrs)), dtype=np.float32)
        self.thrs = thrs
        self.reset()

    def reset(self):
        """Clear all stored IoUs."""
        self.iou.fill(0.)
        self.n = 0  # number of rows of self.iou filled so far

    def add(self, output, target):
        """
        Store one sample: IoU of (output > thr) against (target > 0) for every threshold.

        Samples beyond ``sz`` are silently dropped.
        """
        if self.n >= len(self.iou):
            return
        target, output = target.squeeze(), output.squeeze()
        gt = (target > 0).astype(np.uint8)
        for i, thr in enumerate(self.thrs):
            pred = output > thr
            mask_sum = (pred == 1).astype(np.uint8) + gt
            intxn = np.sum(mask_sum == 2)   # intersection: set in both masks
            union = np.sum(mask_sum > 0)    # union: set in either mask
            if union > 0:
                self.iou[self.n, i] = intxn / union
            else:
                # Both masks empty (an empty union implies an empty intersection): perfect match.
                self.iou[self.n, i] = 1
        self.n += 1

    def value(self, s):
        """
        Summarise the stored samples per threshold.

        :param s: 'mean', 'median', or a number t (fraction of samples with IoU > t)
        :return: np.ndarray with one entry per threshold
        :raises ValueError: for any other ``s``
        """
        # Exactly the first self.n rows hold data. The previous code counted positive
        # *entries* (rows x thresholds). That overshot into unfilled zero rows whenever
        # several thresholds were used, and undershot when a sample had IoU 0.
        nb = max(self.n, 1)
        iou = self.iou[:nb]
        if s == 'mean':
            return np.mean(iou, axis=0)
        if s == 'median':
            return np.median(iou, axis=0)
        try:
            level = float(s)
        except (TypeError, ValueError):
            raise ValueError("unsupported statistic {!r}".format(s))
        return np.sum(iou > level, axis=0) / float(nb)
if __name__ == '__main__':
    # Smoke test: two updates, then read the metrics back (including an unknown key).
    meter = AverageMeter()
    for elapsed, acc in ((1.1, .99), (1.0, .90)):
        meter.update(time=elapsed, accuracy=acc)
    print(meter)
    print(meter.sum)
    timing = meter.time
    print(timing)
    print(timing.avg)
    print(timing.val)
    print(meter.SS)
bbox_helper.py
# --------------------------------------------------------
# 矩形框处理帮助
# --------------------------------------------------------
import numpy as np
from collections import namedtuple
# Corner form: top-left (x1, y1) and bottom-right (x2, y2).
Corner = namedtuple('Corner', ['x1', 'y1', 'x2', 'y2'])
BBox = Corner  # alias: boxes are stored in corner form
# Centre form: centre point (x, y) plus width and height.
Center = namedtuple('Center', ['x', 'y', 'w', 'h'])
def corner2center(corner):
    """
    Convert corner form (x1, y1, x2, y2) to centre form (cx, cy, w, h).

    :param corner: Corner namedtuple, or anything indexable with x1, y1, x2, y2 at
                   positions 0..3 (e.g. a list of arrays or a 4xN np.array)
    :return: Center for Corner input, otherwise a plain tuple (cx, cy, w, h)
    """
    x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
    cx = (x1 + x2) * 0.5
    cy = (y1 + y2) * 0.5
    w = x2 - x1
    h = y2 - y1
    if isinstance(corner, Corner):
        return Center(cx, cy, w, h)
    return cx, cy, w, h
def center2corner(center):
    """
    Convert centre form (cx, cy, w, h) to corner form (x1, y1, x2, y2).

    :param center: Center namedtuple, or anything indexable with cx, cy, w, h at
                   positions 0..3 (e.g. a list of arrays or a 4xN np.array)
    :return: Corner for Center input, otherwise a plain tuple (x1, y1, x2, y2)
    """
    x, y, w, h = center[0], center[1], center[2], center[3]
    half_w = w * 0.5
    half_h = h * 0.5
    corners = (x - half_w, y - half_h, x + half_w, y + half_h)
    if isinstance(center, Center):
        return Corner(*corners)
    return corners
def cxy_wh_2_rect(pos, sz):
    """
    Convert a box given by centre and size to [left, top, width, height] (0-indexed).

    :param pos: box centre (cx, cy)
    :param sz: box size (w, h)
    :return: np.array([left, top, w, h])
    """
    w, h = sz[0], sz[1]
    left = pos[0] - w / 2
    top = pos[1] - h / 2
    return np.array([left, top, w, h])
def get_axis_aligned_bbox(region):
    """
    Reduce a region annotation to an axis-aligned box in centre form.

    :param region: np.array with either 8 values (x1, y1, ..., x4, y4: four corner
                   points, possibly rotated) or 4 values (x, y, w, h: top-left corner
                   and size)
    :return: (cx, cy, w, h)
    """
    if region.size != 8:
        x, y, w, h = region[0], region[1], region[2], region[3]
        return x + w / 2, y + h / 2, w, h

    xs = region[0::2]
    ys = region[1::2]
    cx = np.mean(xs)
    cy = np.mean(ys)
    x1, x2 = min(xs), max(xs)
    y1, y2 = min(ys), max(ys)
    # Shrink the bounding rectangle so its area matches the product of two adjacent
    # side lengths of the polygon (its area when the points form a rectangle in order).
    poly_area = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
    box_area = (x2 - x1) * (y2 - y1)
    s = np.sqrt(poly_area / box_area)
    w = s * (x2 - x1) + 1
    h = s * (y2 - y1) + 1
    return cx, cy, w, h
def aug_apply(bbox, param, shape, inv=False, rd=False):
    """
    Apply (or, with ``inv=True``, undo) a scale/shift augmentation to a box.

    :param bbox: original bbox in image, corner form (Corner x1, y1, x2, y2)
    :param param: augmentation param dict; optional keys 'scale' -> (sx, sy) and
                  'shift' -> (tx, ty)
    :param shape: image shape, h, w, (c)
    :param inv: if True, invert the augmentation described by ``param``
    :param rd: round the resulting bbox coordinates (forward direction only)
    :return: forward: (bbox, real_param), where bbox is the augmented Corner and
             real_param holds the scale/shift actually applied after clamping;
             inverse: the restored box as returned by center2corner
    """
    if not inv:
        # Forward direction: work in centre form so scaling keeps the centre fixed.
        center = corner2center(bbox)
        original_center = center
        real_param = {}
        # Scaling
        if 'scale' in param:
            scale_x, scale_y = param['scale']
            imh, imw = shape[:2]
            h, w = center.h, center.w
            # Cap the scale so the box is never wider/taller than the image.
            scale_x = min(scale_x, float(imw) / w)
            scale_y = min(scale_y, float(imh) / h)
            # center.w *= scale_x
            # center.h *= scale_y
            # Namedtuples are immutable, so build a new Center instead of the in-place update above.
            center = Center(center.x, center.y, center.w * scale_x, center.h * scale_y)
        # Back to corner form (x1, y1, x2, y2).
        bbox = center2corner(center)
        # Translation
        if 'shift' in param:
            tx, ty = param['shift']
            x1, y1, x2, y2 = bbox
            imh, imw = shape[:2]
            # Clamp the shift so the box stays within [0, imw-1] x [0, imh-1] where possible.
            tx = max(-x1, min(imw - 1 - x2, tx))
            ty = max(-y1, min(imh - 1 - y2, ty))
            bbox = Corner(x1 + tx, y1 + ty, x2 + tx, y2 + ty)
        if rd:
            bbox = Corner(*map(round, bbox))
        current_center = corner2center(bbox)
        # Report the scale and shift that were really applied (after clamping/rounding).
        real_param['scale'] = current_center.w / original_center.w, current_center.h / original_center.h
        real_param['shift'] = current_center.x - original_center.x, current_center.y - original_center.y
        return bbox, real_param
    else:
        # Inverse direction: missing entries mean identity scale / zero shift.
        if 'scale' in param:
            scale_x, scale_y = param['scale']
        else:
            scale_x, scale_y = 1., 1.
        if 'shift' in param:
            tx, ty = param['shift']
        else:
            tx, ty = 0, 0
        # Undo the shift on the centre and the scale on the size.
        center = corner2center(bbox)
        center = Center(center.x - tx, center.y - ty, center.w / scale_x, center.h / scale_y)
        return center2corner(center)
def IoU(rect1, rect2):
    """
    Intersection-over-union of two corner-form boxes (x1, y1, x2, y2).

    Works element-wise when the coordinates are numpy arrays. There is no guard
    against a zero union (two degenerate boxes).
    """
    ax1, ay1, ax2, ay2 = rect1[0], rect1[1], rect1[2], rect1[3]
    bx1, by1, bx2, by2 = rect2[0], rect2[1], rect2[2], rect2[3]
    # Overlap extents, clamped at 0 when the boxes do not intersect.
    inter_w = np.maximum(0, np.minimum(bx2, ax2) - np.maximum(bx1, ax1))
    inter_h = np.maximum(0, np.minimum(by2, ay2) - np.maximum(by1, ay1))
    inter = inter_w * inter_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter / (area_a + area_b - inter)
benchmark_helper.py
# --------------------------------------------------------
# 数据集加载帮助类
# --------------------------------------------------------
from os.path import join, realpath, dirname, exists, isdir
from os import listdir
import logging
import glob
import numpy as np
import json
from collections import OrderedDict
def get_dataset_zoo():
    """
    List the dataset folders available under ``<this file's dir>/../data``.

    A sub-directory counts as a dataset when it contains one of the marker files
    checked in ``valid`` below.

    :return: list of dataset directory names. Empty when the data directory does
             not exist; previously this raised FileNotFoundError, and because the
             function runs at import time, importing the module failed.
    """
    root = realpath(join(dirname(__file__), '../data'))
    if not isdir(root):
        return []

    def valid(x):
        y = join(root, x)
        if not isdir(y):
            return False
        return exists(join(y, 'list.txt')) \
            or exists(join(y, 'train', 'meta.json')) \
            or exists(join(y, 'ImageSets', '2016', 'val.txt')) \
            or exists(join(y, 'ImageSets', '2017', 'test-dev.txt'))

    return [name for name in listdir(root) if valid(name)]


# Computed once at import time.
dataset_zoo = get_dataset_zoo()
def load_dataset(dataset):
    """
    Build the per-video index of a benchmark stored under ``<this file's dir>/../data``.

    Dispatch is by substring of ``dataset``: 'VOT', 'DAVIS' (without 'TEST'; its last
    four characters select the ImageSets year), 'ytb_vos', or 'TEST' (DAVIS2017TEST).

    :param dataset: dataset name, e.g. 'VOT2018' or 'DAVIS2016'
    :return: dict keyed by video name. VOT entries have 'image_files', 'gt' (one
             8-value polygon per row) and 'name'. DAVIS/TEST entries have
             'anno_files', 'image_files' and 'name'. ytb_vos entries additionally
             have 'anno_init_files', 'start_frame' and 'end_frame' (per object).
    :raises SystemExit: with status 1 when the VOT data is missing or the dataset is
             not supported. Previously ``exit()`` was used, which reported success
             (status 0) on these errors.
    """
    info = OrderedDict()
    data_root = join(realpath(dirname(__file__)), '../data')
    if 'VOT' in dataset:
        base_path = join(data_root, dataset)
        if not exists(base_path):
            logging.error("Please download test dataset!!!")
            raise SystemExit(1)
        # list.txt names one video directory per line.
        with open(join(base_path, 'list.txt')) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            video_path = join(base_path, video)
            image_files = sorted(glob.glob(join(video_path, '*.jpg')))
            if len(image_files) == 0:  # VOT2018 keeps frames under color/
                image_files = sorted(glob.glob(join(video_path, 'color', '*.jpg')))
            # ndmin=2 keeps a single-line groundtruth 2-D so gt.shape[1] is valid.
            gt = np.loadtxt(join(video_path, 'groundtruth.txt'), delimiter=',', ndmin=2).astype(np.float64)
            if gt.shape[1] == 4:
                # (x, y, w, h) rows -> four corner points (x1, y1, ..., x4, y4).
                gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1,
                                      gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1]))
            info[video] = {'image_files': image_files, 'gt': gt, 'name': video}
    elif 'DAVIS' in dataset and 'TEST' not in dataset:
        base_path = join(data_root, 'DAVIS')
        list_path = join(data_root, 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            info[video] = {}
            info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
            info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
            info[video]['name'] = video
    elif 'ytb_vos' in dataset:
        base_path = join(data_root, 'ytb_vos', 'valid')
        json_path = join(base_path, 'meta.json')
        # Close the metadata file instead of leaking the handle.
        with open(json_path, 'r') as f:
            meta = json.load(f)['videos']
        info = dict()
        for v in meta.keys():
            objects = meta[v]['objects']
            frames = []
            anno_frames = []
            info[v] = dict()
            for obj in objects:
                frames += objects[obj]['frames']
                anno_frames += [objects[obj]['frames'][0]]
            frames = sorted(np.unique(frames))
            info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames]
            info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames]
            info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames]
            info[v]['name'] = v
            # Index of each object's first and last annotated frame in the merged frame list.
            info[v]['start_frame'] = dict()
            info[v]['end_frame'] = dict()
            for obj in objects:
                start_file = objects[obj]['frames'][0]
                end_file = objects[obj]['frames'][-1]
                info[v]['start_frame'][obj] = frames.index(start_file)
                info[v]['end_frame'][obj] = frames.index(end_file)
    elif 'TEST' in dataset:
        base_path = join(data_root, 'DAVIS2017TEST')
        list_path = join(data_root, 'DAVIS2017TEST', 'ImageSets', '2017', 'test-dev.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            info[video] = {}
            info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
            info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
            info[video]['name'] = video
    else:
        logging.error('Not support')
        raise SystemExit(1)
    return info
config_helper.py
# --------------------------------------------------------
# 配置文件处理帮助
# --------------------------------------------------------
import json
from os.path import exists
def proccess_loss(cfg):
    """
    Fill in default loss settings in place.

    :param cfg: the 'loss' section of the config (dict), modified in place
    :return: None
    """
    # Regression branch: L1 loss unless configured otherwise.
    reg = cfg.setdefault('reg', {})
    reg.setdefault('loss', 'L1Loss')
    # Classification branch.
    cfg.setdefault('cls', {'split': True})
    # Relative weights of the cls, reg and mask losses.
    cfg['weight'] = cfg.get('weight', [1, 1, 36])


def add_default(conf, default):
    """
    Overlay ``conf`` on ``default`` and return ``default``.

    Keys from ``conf`` win. ``default`` itself is mutated.
    """
    default.update(conf)
    return default


def load_config(args):
    """
    Load the JSON config file named by ``args.config`` and fill in defaults.

    Side effects on ``args``: ``args.arch`` is set from config['network']['arch'],
    and ``args.clip`` may be overwritten from the config's 'clip' section.

    :param args: parsed command-line namespace; must have ``config``. ``arch`` and
                 ``clip`` are optional.
    :return: config dict with 'network', 'loss' and 'lr' (and possibly 'clip') filled in
    :raises AssertionError: if the config file does not exist
    :raises Exception: if the config has no 'network' section and no arch was given
    """
    # Explicit raise (not assert) so the check also runs under python -O.
    if not exists(args.config):
        raise AssertionError('"{}" not exists'.format(args.config))
    with open(args.config) as f:
        config = json.load(f)

    # Network architecture.
    if 'network' not in config:
        print('Warning: network lost in config. This will be error in next version')
        config['network'] = {}
        arch = getattr(args, 'arch', None)
        if not arch:
            raise Exception('no arch provided')
        # Use the command-line arch. Previously the lookup below then raised KeyError
        # because the freshly created section was empty.
        config['network']['arch'] = arch
    args.arch = config['network']['arch']

    # Loss functions.
    proccess_loss(config.setdefault('loss', {}))

    # Learning-rate settings: defaults overlaid by the config's 'lr' section.
    lr_cfg = {
        'feature_lr_mult': 1.0,
        'rpn_lr_mult': 1.0,
        'mask_lr_mult': 1.0,
        'type': 'log',
        'start_lr': 0.03
    }
    lr_cfg.update(config.get('lr', {}))
    config['lr'] = lr_cfg

    # Gradient clipping: config values win, otherwise fall back to args.clip.
    if 'clip' in config or 'clip' in vars(args):
        cli_clip = getattr(args, 'clip', None)
        config['clip'] = add_default(config.get('clip', {}),
                                     {'feature': cli_clip, 'rpn': cli_clip, 'split': False})
        if config['clip']['feature'] != config['clip']['rpn']:
            config['clip']['split'] = True
        if not config['clip']['split']:
            args.clip = config['clip']['feature']
    return config
load_helper.py
# --------------------------------------------------------
# 模型加载帮助类
# --------------------------------------------------------
import torch
import logging
logger = logging.getLogger('global')


def check_keys(model, pretrained_state_dict):
    """
    Compare checkpoint keys with ``model.state_dict()`` keys and log the differences.

    :param model: anything with ``state_dict()`` (a module, or an optimizer as in restore_from)
    :param pretrained_state_dict: mapping loaded from a checkpoint
    :return: True when at least one key is shared
    :raises AssertionError: when no key is shared. It is raised explicitly rather than
        via ``assert`` so it still fires under ``python -O``; load_pretrain catches it
        to retry with a 'features.' prefix.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys  # present only in the checkpoint
    missing_keys = model_keys - ckpt_keys            # present only in the model
    if len(missing_keys) > 0:
        logger.info('[Warning] missing keys: {}'.format(missing_keys))
        logger.info('missing keys:{}'.format(len(missing_keys)))
    if len(unused_pretrained_keys) > 0:
        logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys))
        logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    logger.info('used keys:{}'.format(len(used_pretrained_keys)))
    if len(used_pretrained_keys) == 0:
        raise AssertionError('load NONE from pretrained checkpoint')
    return True


def remove_prefix(state_dict, prefix):
    """
    Strip ``prefix`` from every key that starts with it.

    Old-style models were saved with every parameter name sharing the common prefix
    'module.' (DataParallel). Slicing also handles an empty prefix, where
    ``str.split('')`` used to raise ValueError.
    """
    logger.info('remove prefix \'{}\''.format(prefix))
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}
def load_pretrain(model, pretrained_path):
    """
    Load weights from ``pretrained_path`` into ``model`` (non-strict) and return it.

    Tensors are mapped to CPU when CUDA is unavailable, otherwise to the current CUDA
    device. A {'state_dict': ...} wrapper is unwrapped and the 'module.' prefix is
    removed. If no key then matches the model, the keys are retried with a 'features.'
    prefix (a backbone-only checkpoint).
    """
    logger.info('load pretrained model from {}'.format(pretrained_path))
    if not torch.cuda.is_available():
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict:
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    try:
        check_keys(model, pretrained_dict)
    except AssertionError:
        # check_keys found no shared key: treat the checkpoint as backbone weights.
        # Only this failure triggers the fallback. Any other error (including
        # KeyboardInterrupt, previously swallowed by a bare except) now propagates.
        logger.info('[Warning]: using pretrain as features. Adding "features." as prefix')
        pretrained_dict = {'features.' + k: v for k, v in pretrained_dict.items()}
        check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
def restore_from(model, optimizer, ckpt_path):
    """
    Restore model and optimizer state from a training checkpoint.

    :param model: model receiving ckpt['state_dict'] (non-strict, 'module.' prefix removed)
    :param optimizer: optimizer receiving ckpt['optimizer']
    :param ckpt_path: checkpoint path
    :return: (model, optimizer, epoch, best_acc, arch), the last three read from the checkpoint
    """
    logger.info('restore from {}'.format(ckpt_path))
    # Same device mapping as load_pretrain. Previously this always called
    # torch.cuda.current_device(), which fails on CPU-only machines.
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage.cuda(device))
    else:
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    epoch = ckpt['epoch']
    best_acc = ckpt['best_acc']
    arch = ckpt['arch']
    ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.')
    check_keys(model, ckpt_model_dict)
    model.load_state_dict(ckpt_model_dict, strict=False)
    check_keys(optimizer, ckpt['optimizer'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, epoch, best_acc, arch
log_helper.py
# --------------------------------------------------------
# 日志帮助类
# --------------------------------------------------------
from __future__ import division
import os
import logging
import sys
import math
# Normalised source path of this module; find_caller() skips stack frames whose
# code comes from this file.
_srcfile = os.path.normcase(
    "logging%s__init__%s" % (os.sep, __file__[-4:]) if hasattr(sys, 'frozen')  # support for py2exe
    else __file__[:-4] + '.py' if __file__[-4:].lower() in ['.pyc', '.pyo']    # map bytecode back to source
    else __file__
)
# Registry of (logger name, level) pairs already configured by init_log(),
# used there to avoid attaching a second handler for the same pair.
logs = set()
class Filter:
    """logging filter with a fixed verdict: lets every record through, or none."""

    def __init__(self, flag):
        self.flag = flag

    def filter(self, x):
        # logging drops the record when this returns a falsy value.
        verdict = self.flag
        return verdict
class Dummy:
    """Null object: accepts any constructor arguments and any method call, doing nothing."""

    def __init__(self, *arg, **kwargs):
        pass

    def __getattr__(self, arg):
        # Every unknown attribute resolves to a no-op callable that returns None.
        return lambda *args, **kwargs: None
# Log formatting for multi-process (SLURM / multi-GPU) runs.
def _slurm_rank(logger, level):
    """Return this process's SLURM rank (0 outside SLURM), muting the logger on ranks != 0 for INFO."""
    if 'SLURM_PROCID' not in os.environ:
        return 0
    rank = int(os.environ['SLURM_PROCID'])
    if level == logging.INFO:
        # On ranks other than 0 this filter rejects every record logged on `logger`.
        logger.addFilter(Filter(rank == 0))
    return rank


def get_format(logger, level):
    """Formatter showing time, rank, file name and line number, then the message."""
    rank = _slurm_rank(logger, level)
    format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank)
    return logging.Formatter(format_str)


def get_format_custom(logger, level):
    """Formatter showing only time and rank; used by LogOnce, whose messages carry their own file/line prefix."""
    rank = _slurm_rank(logger, level)
    format_str = '[%(asctime)s-rk{}-%(message)s'.format(rank)
    return logging.Formatter(format_str)
def init_log(name, level = logging.INFO, format_func=get_format):
    """
    Attach a stream handler to logger ``name``, once per (name, level) pair.

    :param name: logger name
    :param level: level for both the logger and its handler
    :param format_func: callable (logger, level) -> logging.Formatter
    :return: the logger. A repeated call for an already initialised (name, level)
             now returns the same logger instead of None, so callers that keep the
             result (e.g. LogOnce) never end up holding None.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        return logger
    logs.add((name, level))
    logger.setLevel(level)
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(format_func(logger, level))
    logger.addHandler(handler)
    return logger
def add_file_handler(name, log_file, level = logging.INFO):
    """
    Also write logger ``name``'s records to ``log_file``.

    The file handler uses get_format's layout. ``level`` is only passed to that
    formatter setup; the handler itself is given no level.
    """
    target = logging.getLogger(name)
    file_handler = logging.FileHandler(log_file)
    formatter = get_format(target, level)
    file_handler.setFormatter(formatter)
    target.addHandler(file_handler)
# Configure the shared 'global' logger (INFO level, stream handler) when this module is imported.
init_log('global')
def print_speed(i, i_time, n):
    """print_speed(index, index_time, total_iteration)

    Log progress plus an ETA (days:hours:minutes) for the remaining ``n - i``
    iterations, assuming each takes ``i_time`` seconds.
    """
    log = logging.getLogger('global')
    seconds_left = (n - i) * i_time
    days = math.floor(seconds_left / 86400)
    hours = math.floor(seconds_left / 3600 - days * 24)
    minutes = math.floor(seconds_left / 60 - days * 1440 - hours * 60)
    percent = i / n * 100
    message = 'Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' % (
        i, n, percent, i_time, days, hours, minutes)
    log.info(message)
def find_caller():
    """
    Describe the nearest stack frame whose code is not in this module.

    Starts at the caller of find_caller() and walks outward, skipping frames whose
    normalised file name equals ``_srcfile``. If every remaining frame belongs to
    this module, the outermost one seen is reported.
    :return: [file base name, line number, function name]; the placeholder
             ["(unknown file)", 0, "(unknown function)"] is returned only when no
             frame is available at all
    """
    def current_frame():
        """
        Return the frame object for the caller's stack frame,
        i.e. the frame of find_caller itself.
        :return:
        """
        try:
            raise Exception
        except:
            # The traceback's frame is current_frame's own frame; f_back is its caller.
            return sys.exc_info()[2].tb_frame.f_back
    f = current_frame()
    if f is not None:
        # Step past find_caller's own frame.
        f = f.f_back
    rv = "(unknown file)", 0, "(unknown function)"
    while hasattr(f, "f_code"):
        co = f.f_code
        filename = os.path.normcase(co.co_filename)
        rv = (co.co_filename, f.f_lineno, co.co_name)
        if filename == _srcfile:
            # Still inside this logging module: keep walking outward.
            f = f.f_back
            continue
        break
    rv = list(rv)
    rv[0] = os.path.basename(rv[0])
    return rv
# Training-time logging helper that suppresses repeated messages.
class LogOnce:
    """Logs each distinct (call site, message) pair only the first time it occurs."""

    def __init__(self):
        self.logged = set()  # (file, line, function, message) keys already emitted
        self.logger = init_log('log_once', format_func=get_format_custom)

    def log(self, strings):
        fn, lineno, caller = find_caller()
        key = (fn, lineno, caller, strings)
        if key not in self.logged:
            self.logged.add(key)
            self.logger.info("{filename:s}<{caller}>#{lineno:3d}] {strings}".format(
                filename=fn, lineno=lineno, strings=strings, caller=caller))
# Module-wide instance shared by every log_once() call.
once_logger = LogOnce()


def log_once(strings):
    """Log ``strings`` once per call site (file, line, function)."""
    once_logger.log(strings)
lr_helper.py
# --------------------------------------------------------
# 学习率lr更新
# --------------------------------------------------------
from __future__ import division
import numpy as np
import math
from torch.optim.lr_scheduler import _LRScheduler
import matplotlib.pyplot as plt
class LRScheduler(_LRScheduler):
    """
    Base class for schedules driven by a precomputed per-epoch lr table.

    Subclasses must set ``self.lr_spaces`` (one lr per epoch) and ``self.start_lr``
    before calling this constructor, because the torch base class immediately
    calls ``get_lr()``.
    """

    def __init__(self, optimizer, last_epoch=-1):
        if 'lr_spaces' in self.__dict__:
            super(LRScheduler, self).__init__(optimizer, last_epoch)
        else:
            raise Exception('lr_spaces must be set in "LRSchduler"')

    def get_cur_lr(self):
        """Table lr for the current epoch (``self.last_epoch``)."""
        return self.lr_spaces[self.last_epoch]

    def get_lr(self):
        """
        Per-param-group lrs for the current epoch.

        Each group's ``initial_lr`` is scaled by table_lr / start_lr, so groups
        created with different lr multipliers keep their proportions.
        """
        table_lr = self.lr_spaces[self.last_epoch]
        return [table_lr * group['initial_lr'] / self.start_lr for group in self.optimizer.param_groups]

    def __repr__(self):
        return "({}) lr spaces: \n{}".format(type(self).__name__, self.lr_spaces)
class LogScheduler(LRScheduler):
    """Geometric decay: ``epochs`` values evenly spaced in log10 from start_lr to end_lr."""

    def __init__(self, optimizer, start_lr=0.03, end_lr=5e-4, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr, self.end_lr, self.epochs = start_lr, end_lr, epochs
        log_start, log_end = math.log10(start_lr), math.log10(end_lr)
        self.lr_spaces = np.logspace(log_start, log_end, epochs)
        super(LogScheduler, self).__init__(optimizer, last_epoch)
class StepScheduler(LRScheduler):
    """Piecewise-constant schedule: multiply the lr by ``mult`` every ``step`` epochs."""

    def __init__(self, optimizer, start_lr=0.01, end_lr=None, step=10, mult=0.1, epochs=50, last_epoch=-1, **kwargs):
        """
        :param optimizer: optimizer
        :param start_lr: initial lr; if None while end_lr is given, derived from end_lr and mult
        :param end_lr: target lr; given together with start_lr, ``mult`` is recomputed so
                       the lr moves from start_lr towards end_lr (warm-up policy)
        :param step: number of epochs between lr changes
        :param mult: factor applied at each change
        :param epochs: length of the schedule
        :param last_epoch: index of the last finished epoch (-1 to start fresh)
        :param kwargs: ignored, so a whole config dict can be passed through
        """
        if end_lr is not None:
            if start_lr is None:
                start_lr = end_lr / (mult ** (epochs // step))
            else:  # for warm up policy
                # Clamp to 1 change: a schedule shorter than one step (epochs < step)
                # used to divide by zero here. No change happens in that case anyway,
                # so the lr stays at start_lr.
                mult = math.pow(end_lr / start_lr, 1. / max(epochs // step, 1))
        self.start_lr = start_lr
        # lr for epoch e is start_lr * mult ** (e // step).
        self.lr_spaces = self.start_lr * (mult ** (np.arange(epochs) // step))
        self.mult = mult
        self._step = step
        super(StepScheduler, self).__init__(optimizer, last_epoch)
class MultiStepScheduler(LRScheduler):
    """Multiply the lr by ``mult`` at each epoch listed in ``steps``."""

    def __init__(self, optimizer, start_lr=0.01, end_lr=None, steps=None, mult=0.5, epochs=50, last_epoch=-1, **kwargs):
        """
        :param optimizer: optimizer
        :param start_lr: initial lr; if None while end_lr is given, derived from end_lr and mult
        :param end_lr: target lr; given together with start_lr, ``mult`` is recomputed so
                       the lr reaches end_lr after all len(steps) changes
        :param steps: epochs at which the lr changes (default [10, 20, 30, 40])
        :param mult: factor applied at each change
        :param epochs: length of the schedule
        :param last_epoch: index of the last finished epoch (-1 to start fresh)
        :param kwargs: ignored, so a whole config dict can be passed through
        """
        if steps is None:
            # Fresh list per instance instead of a shared mutable default argument.
            steps = [10, 20, 30, 40]
        if end_lr is not None:
            if start_lr is None:
                start_lr = end_lr / (mult ** (len(steps)))
            elif steps:
                # With no steps the lr never changes, so mult is irrelevant; the
                # unguarded form divided by zero.
                mult = math.pow(end_lr / start_lr, 1. / len(steps))
        self.start_lr = start_lr
        self.lr_spaces = self._build_lr(start_lr, steps, mult, epochs)
        self.mult = mult
        self.steps = steps
        super(MultiStepScheduler, self).__init__(optimizer, last_epoch)

    def _build_lr(self, start_lr, steps, mult, epochs):
        """
        Per-epoch lr table: start at start_lr and multiply by mult at every epoch in steps.

        :return: np.float32 array of length ``epochs``
        """
        change_epochs = set(steps)
        lr = [0] * epochs
        lr[0] = start_lr
        for i in range(1, epochs):
            lr[i] = lr[i - 1] * mult if i in change_epochs else lr[i - 1]
        return np.array(lr, dtype=np.float32)
class LinearStepScheduler(LRScheduler):
    """Straight line from start_lr to end_lr (both included) over ``epochs`` epochs."""

    def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr, self.end_lr = start_lr, end_lr
        self.lr_spaces = np.linspace(self.start_lr, self.end_lr, num=epochs)
        super(LinearStepScheduler, self).__init__(optimizer, last_epoch)
class CosStepScheduler(LRScheduler):
    """Half-cosine decay from start_lr towards end_lr over ``epochs`` epochs."""

    def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs):
        self.start_lr, self.end_lr = start_lr, end_lr
        self.lr_spaces = self._build_lr(start_lr, end_lr, epochs)
        super(CosStepScheduler, self).__init__(optimizer, last_epoch)

    def _build_lr(self, start_lr, end_lr, epochs):
        """
        lr(e) = end_lr + (start_lr - end_lr) * (1 + cos(pi * e / epochs)) / 2 for e in [0, epochs).

        Because e never reaches ``epochs``, the last value stays slightly above end_lr.
        :return: np.float32 array of length ``epochs``
        """
        e = np.arange(epochs).astype(np.float32)
        cosine = 1. + np.cos(e * np.pi / epochs)
        lr = end_lr + (start_lr - end_lr) * cosine * 0.5
        return lr.astype(np.float32)
class WarmUPScheduler(LRScheduler):
    """Warm-up table followed by the normal schedule's table."""

    def __init__(self, optimizer, warmup, normal, epochs=50, last_epoch=-1):
        # ``epochs`` is accepted but unused: the length is the sum of both tables.
        self.lr_spaces = np.concatenate([warmup.lr_spaces, normal.lr_spaces])
        # get_lr() scales param-group lrs relative to the normal phase's first value.
        self.start_lr = normal.lr_spaces[0]
        super(WarmUPScheduler, self).__init__(optimizer, last_epoch)
# Registry: cfg['type'] -> scheduler class.
LRs = {
    'log': LogScheduler,
    'step': StepScheduler,
    'multi-step': MultiStepScheduler,
    'linear': LinearStepScheduler,
    'cos': CosStepScheduler,
}
def _build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    Instantiate the scheduler named by cfg['type'] (default 'log').

    ``cfg`` is mutated ('type' is filled in) and then passed as keyword arguments,
    so its other keys must be accepted by the scheduler (every class in LRs takes
    ``**kwargs``).
    :raises Exception: for a type that is not in LRs
    """
    scheduler_type = cfg.setdefault('type', 'log')
    scheduler_cls = LRs.get(scheduler_type)
    if scheduler_cls is None:
        raise Exception('Unknown type of LR Scheduler "%s"'%cfg['type'])
    return scheduler_cls(optimizer, last_epoch=last_epoch, epochs=epochs, **cfg)
def _build_warm_up_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    Warm-up phase from cfg['warmup'] (its 'epoch' key gives the length), followed by
    the main schedule from ``cfg`` for the remaining epochs.
    """
    warm_cfg = cfg['warmup']
    warm_epochs = warm_cfg['epoch']
    ramp = _build_lr_scheduler(optimizer, warm_cfg, warm_epochs, last_epoch)
    main = _build_lr_scheduler(optimizer, cfg, epochs - warm_epochs, last_epoch)
    return WarmUPScheduler(optimizer, ramp, main, epochs, last_epoch)
def build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1):
    """
    Public entry point: warm-up + main schedule when ``cfg`` has a 'warmup' section,
    otherwise a single scheduler of type cfg['type'].
    """
    builder = _build_warm_up_scheduler if 'warmup' in cfg else _build_lr_scheduler
    return builder(optimizer, cfg, epochs, last_epoch)
if __name__ == '__main__':
    # Visual smoke test: build each scheduler type on a dummy optimizer, print its lr
    # table and plot it (depending on the backend, plt.show() may block until the
    # window is closed).
    import torch.nn as nn
    from torch.optim import SGD
    # Minimal model whose only purpose is to give the optimizer some parameters.
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(10, 10, kernel_size=3)
    # Note: this is the parameters() generator, which SGD consumes.
    net = Net().parameters()
    # Optimizer
    optimizer = SGD(net, lr=0.01)
    # 'step' schedule: multiply the lr by 0.1 every 10 epochs
    step = {
        'type': 'step',
        'start_lr': 0.01,
        'step': 10,
        'mult': 0.1
    }
    lr = build_lr_scheduler(optimizer, step)
    print("test1")
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("StepScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    # 'linear' schedule from 0.01 down to 0.005
    step = {
        'type': 'linear',
        'start_lr': 0.01,
        'end_lr':0.005
    }
    lr = build_lr_scheduler(optimizer, step)
    print("test1")
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("linear")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    # 'log' schedule from 0.03 down to 5e-4
    log = {
        'type': 'log',
        'start_lr': 0.03,
        'end_lr': 5e-4,
    }
    lr = build_lr_scheduler(optimizer, log)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("logScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    # 'multi-step' schedule: multiply by 0.1 at epochs 10, 15 and 20
    log = {
        'type': 'multi-step',
        "start_lr": 0.01,
        "mult": 0.1,
        "steps": [10, 15, 20]
    }
    lr = build_lr_scheduler(optimizer, log)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("MultiStepScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    # 'cos' schedule from 0.01 towards 0.0005
    cos = {
        "type": 'cos',
        'start_lr': 0.01,
        'end_lr': 0.0005,
    }
    lr = build_lr_scheduler(optimizer, cos)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("CosScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    # Warm-up: a 10-epoch 'step' ramp from 0.001 up to 0.03 (step=1, so the lr changes
    # every epoch), then the main schedule for the remaining 45 of 55 epochs.
    # NOTE: `log` currently holds the multi-step config, so the main phase is
    # multi-step; the `step` dict gains an 'epoch' key in place below.
    step = {
        'type': 'step',
        'start_lr': 0.001,
        'end_lr': 0.03,
        'step': 1,
    }
    warmup = log.copy()
    warmup['warmup'] = step
    warmup['warmup']['epoch'] = 10
    lr = build_lr_scheduler(optimizer, warmup, epochs=55)
    print(lr)
    plt.plot(lr.lr_spaces)
    plt.grid()
    plt.title("WarmupScheduler")
    plt.xlabel("epochs")
    plt.ylabel("lr")
    plt.show()
    lr.step()
    print(lr.last_epoch)
    # NOTE(review): passing an explicit epoch to step() is deprecated in recent PyTorch releases.
    lr.step(5)
    print(lr.last_epoch)
tracker_config.py
# --------------------------------------------------------
# 目标跟踪器参数设置
# --------------------------------------------------------
from __future__ import division
from utils.anchors import Anchors
# 跟踪器配置参数
class TrackerConfig(object):
penalty_k = 0.09
window_influence = 0.39
lr = 0.38
seg_thr = 0.3 # 分割阈值 for mask
windowing = 'cosine' # 对于较大的位移进行惩罚 to penalize large displacements [cosine/uniform]
# Params from the network architecture, have to be consistent with the training
exemplar_size = 127 # 跟踪目标模板的大小 input z size
instance_size = 255 # 跟踪实例的大小 input x size (search region)
total_stride = 8 #
out_size = 63 # for mask
base_size = 8
score_size = (instance_size-exemplar_size)//total_stride+1+base_size
context_amount = 0.5 # 跟踪目标的周边信息比例 context amount for the exemplar
ratios = [0.33, 0.5, 1, 2, 3] # anchors宽高比
scales = [8, ] # 尺度,即anchor的大小
anchor_num = len(ratios) * len(scales) # anchor的个数
round_dight = 0 #
anchor = []
def update(self, newparam=None, anchors=None):
"""
更新参数
:param newparam: 新的参数
:param anchors: anchors的参数
:return:
"""
# 新的参数直接添加到配置中
if newparam:
for key, value in newparam.items():
setattr(self, key, value)
# 添加anchors的参数
if anchors is not None:
# 若anchors是字典形式的将其转换为Anchors
if isinstance(anchors, dict):
anchors = Anchors(anchors)
# 更新到config中
if isinstance(anchors, Anchors):
self.total_stride = anchors.stride
self.ratios = anchors.ratios
self.scales = anchors.scales
self.round_dight = anchors.round_dight
self.renew()
def renew(self):
"""
更新配置信息
:return:
"""
# 分类尺寸
self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size
# anchor数目
self.anchor_num = len(self.ratios) * len(self.scales)