# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import glob
from tools.test import *
//1.创建了解析对象
//https://www.cnblogs.com/demo-lv/p/12672100.html
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
//2.添加以下参数
//参数1:resume:梗概
parser.add_argument('--resume', default='', type=str, required=True,
metavar='PATH',help='path to latest checkpoint (default: none)')
//参数2:config配置
parser.add_argument('--config', dest='config', default='config_davis.json',
help='hyper-parameter of SiamMask in json format')
//参数3:需要处理的图像序列
parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
//参数4:cpu等硬件信息
parser.add_argument('--cpu', action='store_true', help='cpu mode')
//解析参数
args = parser.parse_args()
参数
if __name__ == '__main__':
#配置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
torch.cuda.is_available()
——cuda是否可用
意思就是有GPU时选择GPU,没有就使用CPU
torch.backends.cudnn.benchmark = True
用途:参考知乎文章
设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。
简单来说就是优化,提高效率
# 配置模型
cfg = load_config(args)
from custom import Custom
// 在siammask_base文件里面有custom.py,这里面有Custom类,导入它
// 放代码在下面了
siammask = Custom(anchors=cfg['anchors'])
if args.resume:
assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
siammask = load_pretrain(siammask, args.resume)
siammask.eval().to(device)
load_config()
是util里的,通过语句from tools.test import *
导入
用法参考
在这里利用这个函数将args里面的参数解析出来
上面提到的Custom类
`class Custom(SiamMask):
def __init__(self, pretrain=False, **kwargs):
super(Custom, self).__init__(**kwargs)
self.features = ResDown(pretrain=pretrain)
self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
self.mask_model = MaskCorr()
def template(self, template):
self.zf = self.features(template)
def track(self, search):
search = self.features(search)
rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
return rpn_pred_cls, rpn_pred_loc
def track_mask(self, search):
search = self.features(search)
rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
pred_mask = self.mask(self.zf, search)
return rpn_pred_cls, rpn_pred_loc, pred_mask`
if args.resume:
assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
siammask = load_pretrain(siammask, args.resume)
这个是为了判断是否存在模型的权重文件
.to(device)
将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。
参考:pytorch .to(device)
eval()
函数
eval() 函数用来执行一个字符串表达式,并返回表达式的值。
在这里,经过查阅: 调用 model.eval() 函数,是为了将 dropout 层 和 batch normalization 层设置为评估模式(非训练模式).
# Parse Image file
img_files = sorted(glob.glob(join(args.base_path, '*.jp*')))//sort排序
ims = [cv2.imread(imf) for imf in img_files]
这段是在读取图片序列
# Select ROI选择目标区域
cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
# cv2.setWindowProperty("SiamMask", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
//cv2.namedWindow(‘窗口标题’,默认参数)
//https://blog.csdn.net/xykenny/article/details/90513480
//这里将目标框用矩形左上角坐标x,y、宽w、高h的形式来进行表达
try:
init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
x, y, w, h = init_rect
except:
exit()
toc = 0
for f, im in enumerate(ims):
//enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,
//同时列出数据和数据下标,一般用在 for 循环当中
tic = cv2.getTickCount()
//cv2.getTickCount(): 获得时钟次数,跟性能评价有关
//https://www.cnblogs.com/silence-cho/p/10926248.html
if f == 0: # init 初始化
# 目标位置
target_pos = np.array([x + w / 2, y + h / 2])
# 目标的大小
target_sz = np.array([w, h])
# 对目标追踪进行初始化
state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'], device=device) # init tracker
elif f > 0: # tracking 追踪过程
# 目标追踪过程中进行参数state的更新
state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device) # track
# location ——确定追踪目标的位置
location = state['ploygon'].flatten()
# 生成目标分割的mask掩码
mask = state['mask'] > state['p'].seg_thr
# 这里是将mask掩码显示在图像上面,就是视频中那一块红色
im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
# 绘制目标的位置信息,就是画个框
cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
cv2.imshow('SiamMask', im)
key = cv2.waitKey(1) //cv2.waitKey()在没有按键按下的时候返回值为-1,如果有按键值返回按键值
if key > 0: //若有按键按下就跳出循环
break
toc += cv2.getTickCount() - tic
toc /= cv2.getTickFrequency() //cv2.getTickFrequency():获得时钟频率 (每秒振动次数)
fps = f / toc
print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visulization!)'.format(toc, fps))