YOLOX Series Part 2 -- A Detailed Look at tools/demo.py




Preface

This article explains the code in demo.py.


I. Code Walkthrough

1. Argument setup (running demo.py directly from PyCharm)

Project directory structure

The img folder stores the test images; YOLOX_outputs under tools stores the test output.
# Command-line argument setup
def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(                                                         # type of input to detect
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(                                                       # path to the input to detect
        "--path", default="../img", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        default="true",
        # action="store_true",
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(                # path to the experiment (model) .py file
        "-f",
        "--exp_file",
        default="../exps/default/yolox_x.py",
        type=str,
        help="pls input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default="../weights/yolox_x.pth", type=str, help="ckpt for eval")   #权重路径
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.3, type=float, help="test conf")            #conf
    parser.add_argument("--nms", default=0.45, type=float, help="test nms threshold")    #nms
    parser.add_argument("--tsize", default=640, type=int, help="test img size")        #图片尺寸
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser

By setting the default values of the arguments inside the .py file, demo.py can be run directly (for example from PyCharm) instead of passing the arguments on the command line.

--demo          type of input to detect: image, video, or webcam
--path          path to the input; if it is a folder, every matching file in it is detected,
                and if it is a single image, only that image is detected
--save_result   set to default="true", so detection results are saved by default
-f              the experiment file of the model to use, e.g. yolox_x
-c              path to that model's weights
--device        cpu or gpu
--conf          confidence threshold
--nms           NMS threshold
--tsize         network input size
The remaining arguments can be left at their defaults and need no changes.
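
To run demo.py straight from PyCharm, the defaults above are already enough. The snippet below is a hypothetical sketch (it is not part of demo.py) showing how the same defaults could be parsed and overridden in code:

# Minimal sketch (not in demo.py): parse the defaults from make_parser()
# and override a couple of values, just as editing the default= fields would.
args = make_parser().parse_args([])      # empty list -> keep only the defaults
args.demo = "image"                      # detect images
args.path = "../img"                     # folder holding the test images
print(args.conf, args.nms, args.tsize)   # -> 0.3 0.45 640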

2. get_image_list(path): collect the images (all images under a folder)

def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):   # walk the directory tree
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]                 # file extension
            if ext in IMAGE_EXT:                             # keep only supported image formats
                image_names.append(apath)
    return image_names

get_image_list(path) collects the network input: it returns a list of the image file paths found under path, keeping only files whose extension appears in IMAGE_EXT (the list of supported image extensions defined at the top of demo.py).
It is called from image_demo(predictor, vis_folder, path, current_time, save_result).
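
A hypothetical call (the file names are made up) looks like this:

# Hypothetical example: collect every supported image under ../img
files = get_image_list("../img")
# files is a list of paths, e.g. ['../img/dog.jpg', '../img/street.png']
files.sort()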


3. Predictor(object): predictor initialization

# Predictor: wraps the model and the inference settings
class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))     # load the converted TensorRT model

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)                                        # one dummy forward pass to initialize the model
            self.model = model_trt                               # then switch to the TensorRT model
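
The class is not used on its own in this section; for reference, main() (section 8 below) constructs it roughly like this, a minimal sketch for the plain CPU case:

# Minimal sketch of how main() builds the predictor (CPU, no TensorRT):
predictor = Predictor(
    model, exp, COCO_CLASSES,
    trt_file=None, decoder=None,
    device="cpu", fp16=False, legacy=False,
)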

4. inference(self, img): detect an image (or frame)

    def inference(self, img):
        '''
        Detect a single image or video frame.
        Args:
            img: the image / video frame to detect (a file path or an ndarray)

        Returns:
            outputs: the detection results
            img_info: id, file_name, height, width, raw_img, ratio
                      (ratio = min(tsize / img height, tsize / img width))

        '''

        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)       # run the model forward pass
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

outputs is a list with one entry per input image; each entry is a 2-D tensor (one detection per row), or None if nothing was detected.
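
The 7 columns of each row are exactly what visual() reads later; a small sketch of unpacking them (assuming outputs[0] is not None):

# Sketch: unpack the detections of the first (and only) input image
det = outputs[0]                       # tensor of shape (num_detections, 7), or None
if det is not None:
    boxes  = det[:, 0:4]               # x1, y1, x2, y2 on the resized input image
    scores = det[:, 4] * det[:, 5]     # objectness confidence * class confidence
    labels = det[:, 6].int()           # class indices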


5. visual(self, output, img_info, cls_conf=0.35)

    def visual(self, output, img_info, cls_conf=0.35):
        '''

        Args:
            output: detection results for one image
            img_info: image information
            cls_conf: confidence threshold for drawing

        Returns:
            vis_res: the processed image (with boxes, labels, etc. drawn on it)

        '''
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]    # predicted box coordinates

        # undo the resize preprocessing: map the boxes back to the original image size
        bboxes /= ratio

        cls = output[:, 6]                     # class index
        scores = output[:, 4] * output[:, 5]   # score = obj_conf * cls_conf

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res

visual(self, output, img_info, cls_conf=0.35) draws the prediction results onto the image. It calls vis(img, bboxes, scores, cls, cls_conf, self.cls_names), which lives in yolox/utils/visualize.py:

def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):

    for i in range(len(boxes)):                                                          # loop over the detections
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()                            # box color for this class
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)                         # label text
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)         # text color (black on light colors, white on dark)
        font = cv2.FONT_HERSHEY_SIMPLEX                                                      # font

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)                                     # draw the bounding box

        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)    # draw the label text on the image

    return img

6. image_demo(predictor, vis_folder, path, current_time, save_result)

# Image detection
def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)          # collect the images to detect
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)       # call inference() to detect the image
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)  # visualize the predictions
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)   # timestamped output folder name
            )
            os.makedirs(save_folder, exist_ok=True)              # create the output folder
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)   # save the image
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break

image_demo() runs detection on images (a condensed sketch of the loop follows the list below):

  1. get_image_list() collects the images to detect
  2. predictor.inference() runs the prediction on each image
  3. predictor.visual() visualizes the prediction results
  4. cv2.imwrite() saves the result image
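
Condensed into a few lines, the loop above looks roughly like this (a simplified sketch that skips the timestamped output folder):

# Simplified sketch of the image_demo() loop; results are written to the current folder
for image_name in get_image_list("../img"):
    outputs, img_info = predictor.inference(image_name)
    result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
    cv2.imwrite(os.path.basename(image_name), result_image)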

7. imageflow_demo(predictor, vis_folder, current_time, args)

# Video / webcam detection
def imageflow_demo(predictor, vis_folder, current_time, args):
    # open the video source
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)  # video file or webcam
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(                                        # timestamped output folder name
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)                            # create the output folder
    # path of the saved video file
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    # detect the video frame by frame
    '''
    outputs: the detection results (a list of tensors)
    each row of outputs[0] is (x1, y1, x2, y2, obj_conf, cls_conf, cls_id)
    predictor.inference  detects a single frame
    predictor.visual     draws the boxes on the frame
    '''
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break


8. main(exp, args)

def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
        if args.fp16:
            model.half()                                      # to FP16
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model weights
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy)
    current_time = time.localtime()

    # dispatch by detection type
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)

main() does the following:
  1. Read the network arguments and build the model
  2. If the input is an image, call image_demo()
  3. If the input is a video or webcam, call imageflow_demo()

9. The __main__ entry point

if __name__ == "__main__":
    args = make_parser().parse_args()        # parse the command-line arguments
    exp = get_exp(args.exp_file, args.name)  # load the experiment (model) description

    main(exp, args)
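
get_exp() comes from yolox.exp and loads the experiment description either from an exp file (-f) or by a built-in model name (-n). A minimal sketch (the "yolox-x" name is an assumption based on YOLOX's naming convention):

from yolox.exp import get_exp

# load the experiment from an exp file, as the defaults above do
exp = get_exp("../exps/default/yolox_x.py", None)
# or, presumably, by built-in model name instead of a file:
# exp = get_exp(None, "yolox-x")
print(exp.test_size, exp.test_conf, exp.nmsthre, exp.num_classes)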



Summary

That is all for this article: it walked through the code of tools/demo.py and explained the main functions and how they call each other. If you spot any mistakes, please point them out. If you found it useful, a like would be much appreciated!