从小函数开始阅读代码
1 test_net()
def test_net(
        args,
        dataset_name,
        proposal_file,
        output_dir,
        ind_range=None,
        gpu_id=0):
    """Run inference on all images in a dataset or over an index range of images
    in a dataset using a single GPU.

    Predicts on either all images of the dataset or only the images selected
    by ``ind_range``, and returns ``(all_boxes, all_segms, all_keyps)``.

    The format of ``all_boxes`` is worth noting. Box detections are collected
    into:
        all_boxes[cls][image] = N x 5 array with columns (x1, y1, x2, y2, score)

    An additional side effect of this function is that the predictions are
    saved directly to a ``.pkl`` file inside ``output_dir``:
    ``detection_range_<a>_<b>.pkl`` when ``ind_range`` is given, otherwise
    ``detections.pkl``.

    Args:
        output_dir: directory in which the ``.pkl`` detections file is written.
        ind_range: ``None`` to name the output file ``detections.pkl``;
            otherwise a sequence of exactly two values that are formatted into
            the output filename (presumably ``(start, end)`` image indices —
            confirm against the caller).
        args, dataset_name, proposal_file, gpu_id: not referenced in the lines
            shown in this excerpt (the inference code is omitted here).

    Returns:
        The tuple ``(all_boxes, all_segms, all_keyps)``.
    """
    # NOTE(review): this is an excerpt. The code that computes all_boxes,
    # all_segms, all_keyps and cfg_yaml is not shown; those names are not
    # assigned anywhere in the lines below.

    # Pick the output filename. A range run gets its own file name so that
    # its results are kept apart from a whole-dataset run. tuple(ind_range)
    # fills two %s slots, so ind_range must have exactly two elements.
    if ind_range is not None:
        det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
    else:
        det_name = 'detections.pkl'
    det_file = os.path.join(output_dir, det_name)
    # Save the predictions together with cfg_yaml (presumably a YAML dump of
    # the config used for this run — not shown in this excerpt).
    save_object(
        dict(
            all_boxes=all_boxes,
            all_segms=all_segms,
            all_keyps=all_keyps,
            cfg=cfg_yaml
        ), det_file
    )
    logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file)))  # log where the detection results were saved
    return all_boxes, all_segms, all_keyps
2 test_net_on_dataset()：Run inference on a dataset.（在整个数据集上运行推理）
若使用多个 GPU 预测，则调用 multi_gpu_test_net_on_dataset；
若使用单个 GPU 预测，则调用 test_net()。
最后返回 results，其中包含 mAP、recall 等评估结果（具体内容有待进一步确认）。
# Evaluate the prediction results (an important step).
results = task_evaluation.evaluate_all(
    dataset, all_boxes, all_segms, all_keyps, output_dir
)
3 run_inference()
调用 test_net_on_dataset，最后将各项结果汇总，输出包含 mAP、recall 等指标的评估结果界面。
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Test a Detectron network on an imdb (image database)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from collections import defaultdict
import cv2
import datetime
import logging
import numpy as np
import os
import yaml
import torch
from core.config import cfg
# from core.rpn_generator import generate_rpn_on_dataset #TODO: for rpn only case
# from core.rpn_generator import generate_rpn_on_range
from core.test import im_detect_all
from datasets import task_evaluation
from datasets.json_dataset import JsonDataset
from modeling import model_builder
import nn as mynn
from utils.detectron_weight_helper import load_detectron_weight
import utils.env as envu
import utils.net as net_utils
import utils.subprocess as subprocess_utils
import utils.vis as vis_utils
from utils.io import save_