python parse添加参数demo

本文介绍了Python脚本中使用argparse模块时,`action='store_true'`参数的作用。通过示例说明了如何通过命令行选项来设置变量的真假值,以及在不指定参数时的默认行为。建议在运行前通过打印变量值来确认其状态,确保程序按预期运行。
摘要由CSDN通过智能技术生成
# Copyright 2019 Xilinx Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import logging
import os
import sys
if os.environ["W_QUANT"]=='1':
    import pytorch_nndct
    from pytorch_nndct.apis import torch_quantizer, dump_xmodel

import torch
from torch import nn

import network
from core.config import opt, update_config
from core.loader import get_data_provider
from core.solver import Solver
from ipdb import set_trace

FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=FORMAT,
    stream=sys.stdout
)


def test(args):
    logging.info('======= user config ======')
    logging.info(print(args))
    logging.info('======= end ======')

    train_data, test_data, num_query, num_class = get_data_provider(opt, args.dataset, args.dataset_root)

    net = getattr(network, opt.network.name)(opt.network.backbone, num_class, opt.network.last_stride)
    checkpoint = torch.load(args.load_model, map_location=opt.device)
    for i in checkpoint:
        if 'classifier' in i:
            continue
        net.state_dict()[i].copy_(checkpoint[i])
    logging.info('load model checkpoint: {}'.format(args.load_model))
    
    if args.device=='gpu' and args.quant_mode=='float':
        net = nn.DataParallel(net).to(opt.device)
    net = net.to(opt.device)

    resize_wh = opt.aug.resize_size
    x = torch.randn(1,3,resize_wh[0],resize_wh[1]).to(opt.device)
    if args.quant_mode == 'float':
        quant_model = net
    else:
        quantizer = torch_quantizer(args.quant_mode, net, (x), output_dir=args.output_path, device=opt.device)
        quant_model = quantizer.quant_model.to(opt.device)
    quant_model.eval()
    mod = Solver(opt, quant_model)
    mod.test_func(test_data, num_query)
     
    if args.quant_mode == 'calib':
        quantizer.export_quant_config()
    if args.quant_mode == 'test' and args.dump_xmodel:
        dump_xmodel(output_dir=args.output_path, deploy_check=True)
   
def main():
    parser = argparse.ArgumentParser(description='reid model testing')
    parser.add_argument('--dataset', type=str, default = 'market1501', 
                        help = 'set the dataset for test')
    parser.add_argument('--dataset_root', type=str, default = '../../data/market1501',
      
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值