这次为SiamMask写一个在VOT数据集上运行的调用程序：在SiamMask工程文件夹下新建test文件夹，程序就写在这个文件夹里。
为了方便，我给parser的各个参数都设置了默认值。另外，如果想换一个测试序列，只需修改文件头部的name变量：
import glob
from tools.test import *
import sys
import os
# Name of the VOT2016 sequence to run; change it to try another example.
name = "bag"
# Default paths are resolved relative to this script's folder (the ``test`` folder
# created inside the SiamMask project) instead of the current working directory,
# so the script also works when launched from another directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_EXPERIMENT_DIR = os.path.join(_SCRIPT_DIR, '..', 'experiments', 'siammask')
# Makes ``from custom import Custom`` (done in the main block) importable.
sys.path.append(_EXPERIMENT_DIR)
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
# The old help text claimed "(default: none)" although a default checkpoint is set.
parser.add_argument('--resume', default=os.path.join(_EXPERIMENT_DIR, 'SiamMask_VOT.pth'),
                    help='path to the model checkpoint (default: %(default)s)')
parser.add_argument('--config', dest='config',
                    default=os.path.join(_EXPERIMENT_DIR, 'config_vot.json'),
                    help='hyper-parameter of SiamMask in json format')
# Keep the trailing separator so file names can be appended to this path directly.
parser.add_argument('--base_path',
                    default=os.path.join(_SCRIPT_DIR, '..', 'data', 'VOT2016', name, ''),
                    help='datasets')
args = parser.parse_args()
def num_read_directory(directory_name):
    """Count the entries of *directory_name* as listed by ``os.listdir``.

    Files of every type and sub-directories are all included; the special
    entries ``.`` and ``..`` are not.
    """
    return sum(1 for _entry in os.listdir(directory_name))
def read_directory(directory_name, num):
    """Read the *num*-th entry of *directory_name* with ``cv2.imread`` in colour mode.

    Entries are taken in sorted file-name order. ``os.listdir`` returns names in
    an arbitrary, filesystem-dependent order, so indexing it directly would not
    reliably map *num* to the same file. Every entry (not only images) is
    counted, the same way ``num_read_directory`` counts them.

    Returns whatever ``cv2.imread`` returns (``None`` if the entry cannot be
    decoded as an image). Raises ``IndexError`` if *num* is out of range.
    """
    filenames = sorted(os.listdir(directory_name))
    return cv2.imread(os.path.join(directory_name, filenames[num]), cv2.IMREAD_COLOR)
if __name__ == '__main__':
    # Setup device: the GPU when one is available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    # Setup model: build SiamMask from the JSON config and load the checkpoint.
    cfg = load_config(args)
    # ``custom`` is importable because the experiment directory was added to sys.path.
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    if args.resume:
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('{} is not a valid file'.format(args.resume))
        siammask = load_pretrain(siammask, args.resume)
    siammask.eval().to(device)

    # Parse image files, sorted so the frames are processed in file-name order.
    img_files = sorted(glob.glob(os.path.join(args.base_path, '*.jp*')))
    if not img_files:
        sys.exit('No *.jp* images found in {}'.format(args.base_path))
    ims = [cv2.imread(imf) for imf in img_files]
    unreadable = [imf for imf, im in zip(img_files, ims) if im is None]
    if unreadable:
        sys.exit('Could not read image(s): {}'.format(', '.join(unreadable)))

    # Parse groundtruth.txt: each non-empty line holds 8 comma-separated numbers
    # (presumably one tilted box per frame). The array is sized from the file itself;
    # counting directory entries would also count groundtruth.txt, other annotation
    # files and a SiamMask_out.avi left by a previous run.
    # NOTE: ``frames`` is not used yet -- the tracker is initialised from a manual ROI.
    gt_path = os.path.join(args.base_path, 'groundtruth.txt')
    with open(gt_path) as gt_file:
        gt_rows = [[float(v) for v in line.split(',')] for line in gt_file if line.strip()]
    frames = np.zeros([len(gt_rows), 8])
    for row_idx, row in enumerate(gt_rows):
        frames[row_idx, :] = row

    # Select ROI on the first frame by hand.
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    try:
        init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
    except cv2.error as err:
        sys.exit('ROI selection failed: {}'.format(err))
    x, y, w, h = init_rect
    print([x, y, w, h])
    if w == 0 or h == 0:
        # selectROI returns an all-zero rectangle when the selection is cancelled.
        sys.exit('No ROI selected.')

    toc = 0
    frames_timed = 0  # frames whose time is included in ``toc``
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(os.path.join(args.base_path, 'SiamMask_out.avi'), fourcc, 24,
                          (ims[0].shape[1], ims[0].shape[0]))
    try:
        for frame_idx, im in enumerate(ims):
            tic = cv2.getTickCount()
            if frame_idx == 0:  # init
                target_pos = np.array([x + w / 2, y + h / 2])
                target_sz = np.array([w, h])
                # NOTE(review): upstream SiamMask tools/demo.py also passes ``device=device``
                # to siamese_init/siamese_track; check the installed tools/test.py signature
                # if running on CUDA fails with a device mismatch.
                state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'])  # init tracker
            else:  # tracking
                state = siamese_track(state, im, mask_enable=True, refine_enable=True)  # track
                # 'ploygon' (sic) is the key under which the tracker state stores the polygon.
                location = state['ploygon'].flatten()
                mask = state['mask'] > state['p'].seg_thr
                # Paint the mask into channel 2 (red in OpenCV's BGR order).
                im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
                # np.intp replaces np.int0, an alias removed in NumPy 2.0.
                cv2.polylines(im, [np.intp(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
                cv2.putText(im, "SiamMask", (5, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 0, 255), 2)
                cv2.imshow('SiamMask', im)
                out.write(im)
                key = cv2.waitKey(1)
                if key > 0:  # any key press stops tracking
                    break
            toc += cv2.getTickCount() - tic
            frames_timed += 1
    finally:
        # Finalise the AVI file and close the window even if tracking raises.
        out.release()
        cv2.destroyAllWindows()

    toc /= cv2.getTickFrequency()
    # Divide by the number of frames actually timed (the old code used the last index).
    fps = frames_timed / toc if toc > 0 else 0.0
    print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visulization!)'.format(toc, fps))
这里进行模板初始化时，仍然采用人工框选区域的方式。而从groundtruth.txt中读出的是一个倾斜（旋转）的矩形框，由4个顶点共8个坐标值表示。要直接用它来初始化，就需要修改SiamMask原本的初始化程序。这一点我暂时还没有修改，官方也没有说明该如何修改。
下面放一下运行效果：