siammask代码(1)demo.py阅读查询

# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import glob
from tools.test import *


//1.创建了解析对象
//https://www.cnblogs.com/demo-lv/p/12672100.html
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
//2.添加以下参数
//参数1:resume:梗概
parser.add_argument('--resume', default='', type=str, required=True,
                    metavar='PATH',help='path to latest checkpoint (default: none)')
//参数2:config配置
parser.add_argument('--config', dest='config', default='config_davis.json',
                    help='hyper-parameter of SiamMask in json format')
//参数3:需要处理的图像序列
parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
//参数4:cpu等硬件信息
parser.add_argument('--cpu', action='store_true', help='cpu mode')
//解析参数
args = parser.parse_args()

参数

if __name__ == '__main__':
    #配置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

torch.cuda.is_available()——cuda是否可用
意思就是有GPU时选择GPU,没有就使用CPU
torch.backends.cudnn.benchmark = True
用途:参考知乎文章

设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。

简单来说就是优化,提高效率


    # 配置模型
    cfg = load_config(args)
    from custom import Custom
    //   在siammask_base文件里面有custom.py,这里面有Custom类,导入它
    //   放代码在下面了
    siammask = Custom(anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)

    siammask.eval().to(device)
    
    

load_config() 是util里的,通过语句from tools.test import *导入
用法参考
在这里利用这个函数将args里面的参数解析出来

上面提到的Custom类

`class Custom(SiamMask):
    def __init__(self, pretrain=False, **kwargs):
        super(Custom, self).__init__(**kwargs)
        self.features = ResDown(pretrain=pretrain)
        self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
        self.mask_model = MaskCorr()

    def template(self, template):
        self.zf = self.features(template)

    def track(self, search):
        search = self.features(search)
        rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
        return rpn_pred_cls, rpn_pred_loc

    def track_mask(self, search):
        search = self.features(search)
        rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
        pred_mask = self.mask(self.zf, search)
        return rpn_pred_cls, rpn_pred_loc, pred_mask`
 if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)

这个是为了判断是否存在模型的权重文件
.to(device)
将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。
参考:pytorch .to(device)

eval() 函数
eval() 函数用来执行一个字符串表达式,并返回表达式的值。
在这里,经过查阅: 调用 model.eval() 函数,是为了将 dropout 层 和 batch normalization 层设置为评估模式(非训练模式).


    # Parse Image file
    img_files = sorted(glob.glob(join(args.base_path, '*.jp*')))//sort排序
    ims = [cv2.imread(imf) for imf in img_files]

这段是在读取图片序列

    # Select ROI选择目标区域
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    # cv2.setWindowProperty("SiamMask", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    //cv2.namedWindow(‘窗口标题’,默认参数)
    //https://blog.csdn.net/xykenny/article/details/90513480
   //这里将目标框用矩形左上角坐标x,y、宽w、高h的形式来进行表达
    try:
        init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
        x, y, w, h = init_rect
    except:
        exit()

    toc = 0
    for f, im in enumerate(ims): 
     //enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,
   	//同时列出数据和数据下标,一般用在 for 循环当中
        tic = cv2.getTickCount()
         //cv2.getTickCount(): 获得时钟次数,跟性能评价有关
          //https://www.cnblogs.com/silence-cho/p/10926248.html
        if f == 0:  # init  初始化
            # 目标位置
            target_pos = np.array([x + w / 2, y + h / 2])
            # 目标的大小
            target_sz = np.array([w, h])
            # 对目标追踪进行初始化
            state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'], device=device)  # init tracker
        elif f > 0:  # tracking   追踪过程
        	# 目标追踪过程中进行参数state的更新
            state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device)  # track
            # location ——确定追踪目标的位置
            location = state['ploygon'].flatten()
            # 生成目标分割的mask掩码
            mask = state['mask'] > state['p'].seg_thr
			
			# 这里是将mask掩码显示在图像上面,就是视频中那一块红色
            im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
            # 绘制目标的位置信息,就是画个框
            cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
            cv2.imshow('SiamMask', im)
            key = cv2.waitKey(1) //cv2.waitKey()在没有按键按下的时候返回值为-1,如果有按键值返回按键值
            if key > 0:         //若有按键按下就跳出循环
                break

        toc += cv2.getTickCount() - tic
    toc /= cv2.getTickFrequency()  //cv2.getTickFrequency():获得时钟频率 (每秒振动次数)
    fps = f / toc
    print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visulization!)'.format(toc, fps))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值