【无标题】

# Copyright 2020 The TensorPilot Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
""" The evaluator for Waymo. """

import os, cv2
import numpy as np
import pickle, time
from tqdm import tqdm
import matplotlib.pylab as plt
import pdb

from tensorpilot.utils import Stopwatch # ???
from tensorpilot.engine import MPIEvaluator
from tensorpilot.models.detectors.bbox import BBoxXYWHCoder
from tensorpilot.datasets.waymo.common import WAYMO_CLASSES
from tensorpilot.visualizers import DetectionVisualizerWrapper


# Names exported by `from <module> import *`.
__all__ = ["WaymoEvaluator"]


def plot_figure(ths, wsm, idsr, prec, rec, save_path=None):
  """Plot ID-switch rate, precision and recall against the similarity threshold.

  Every point of each curve is annotated with its value. A dashed vertical
  "fws" line is drawn at ``wsm``.

  Args:
    ths: thresholds (x axis).
    wsm (float): x position of the dashed vertical "fws" line.
    idsr: ID-switch rate per threshold.
    prec: precision per threshold.
    rec: recall per threshold.
    save_path (str, optional): when given, the figure is saved there and then
      closed. Otherwise it is left open for the caller (e.g. ``plt.show()``).
  """
  fig = plt.figure(figsize=(16, 10))
  # (values, legend label, line/text colour, marker) for each curve.
  series = (
      (idsr, 'IDSwR', 'r', 'o'),
      (prec, 'prec', 'g', '*'),
      (rec, 'rec', 'b', '^'),
  )
  for values, label, color, marker in series:
    plt.plot(ths, values, '-.', label=label, linewidth=3, color=color,
             marker=marker, markerfacecolor='lime', markersize=10)
    for x, y in zip(ths, values):
      plt.text(x, y, f'{y:.3f}', ha='center', va='bottom', fontsize=15,
               color=color)

  plt.ylabel('rate')
  plt.xlabel('thresh')
  plt.xticks(np.arange(0.1, 1.1, 0.1))
  plt.vlines(wsm, 0, 1, label='fws', color='y', linestyles='dashed', linewidth=3)
  plt.text(wsm + 0.005, 0, f'{wsm:.3f}', color='y', fontsize=15)
  plt.legend()
  plt.grid()
  if save_path is not None:
    plt.savefig(save_path)
    # pyplot keeps every figure alive until it is closed. This function runs
    # once per evaluation, so without closing, figures pile up (and pyplot
    # warns after 20).
    plt.close(fig)

def print_scores(mAP, cmc_scores, num, ws, p_str=''):
  """Print a single summary line of retrieval scores.

  Args:
    mAP (float): mean average precision, shown as a percentage.
    cmc_scores: CMC curve supporting fancy indexing (e.g. a numpy array).
      Entries 0, 4 and 9 are shown as cmc1 / cmc5 / cmc10.
    num (int): number of queries behind the scores.
    ws (float): value shown in the ``WS`` field.
    p_str (str): optional label printed, left-aligned to 10 chars, first.
  """
  if p_str:
    print('{:<10}:'.format(p_str), end='')
  top1, top5, top10 = cmc_scores[[0, 4, 9]]
  line_fmt = ('[mAP: {:5.2%}], [cmc1: {:5.2%}], [cmc5: {:5.2%}], '
              '[cmc10: {:5.2%}], [WS: {:.3f}], [num: {:d}]')
  print(line_fmt.format(mAP, top1, top5, top10, ws, num))

def reshape_outputs(inputs):
  """Group batched model outputs by clip and concatenate each clip's arrays.

  A batch's clip id is the second-to-last ``'_'``-separated field of its
  FIRST sample id (``ids[0]`` decoded as UTF-8). The whole batch is filed
  under that clip. NOTE(review): this assumes a batch never spans clips;
  confirm against the data pipeline.

  Args:
    inputs (iterable of dict): batches with keys ``'ids'`` (byte-string sample
      ids), ``'tids'`` (track ids) and ``'outputs'`` whose
      ``'reid_embeddings'`` presumably holds one embedding row per sample.

  Returns:
    dict: clip id -> {'ids': hstacked track ids,
                      'feats': vstacked embeddings,
                      'full_ids': hstacked sample ids},
    in first-seen clip order.
  """
  grouped = {}
  for batch in inputs:
    clip_id = batch['ids'][0].decode('utf-8').split('_')[-2]
    entry = grouped.setdefault(clip_id, {'ids': [], 'feats': [], 'full_ids': []})
    entry['ids'].append(batch['tids'])
    entry['feats'].append(batch['outputs']['reid_embeddings'])
    entry['full_ids'].append(batch['ids'])

  return {
      clip_id: {'ids': np.hstack(val['ids']),
                'feats': np.vstack(val['feats']),
                'full_ids': np.hstack(val['full_ids'])}
      for clip_id, val in grouped.items()
  }

def compute_mAP(index, good_index):
  """Compute average precision and the CMC curve for a single query.

  Args:
    index (np.ndarray): gallery indices ordered from best to worst match.
    good_index (np.ndarray): gallery indices sharing the query's identity.

  Returns:
    (ap, cmc):
      ap: average precision. Each hit adds ``1 / len(good_index)`` recall,
        weighted by the mean of the precision at the hit and at the rank
        just before it.
      cmc: array of length ``max(len(index), 10)``. It is 1 from the rank of
        the first hit onward and 0 before. When ``good_index`` is empty,
        ``ap`` is 0 and ``cmc[0]`` is -1, which marks the query as invalid.
  """
  ap = 0
  cmc = np.zeros(max(len(index), 10))
  if good_index.size == 0:
    cmc[0] = -1
    return ap, cmc
  ngood = len(good_index)
  # 0-based ranks at which a same-identity item appears.
  rows_good = np.flatnonzero(np.isin(index, good_index))
  cmc[rows_good[0]:] = 1
  d_recall = 1.0 / ngood
  # np.float64 instead of the np.float alias, which NumPy 1.24 removed.
  hits = np.arange(len(rows_good), dtype=np.float64)
  precision = (hits + 1) / (rows_good + 1)
  if rows_good[0] == 0:
    # A hit at rank 0 has no preceding rank; its "previous" precision is 1.
    old_precision = np.ones(len(rows_good))
    old_precision[1:] = hits[1:] / rows_good[1:]
  else:
    old_precision = hits / rows_good
  ap = np.sum((precision + old_precision) / 2. * d_recall)
  return ap, cmc

def compute_idswitch(dist, ids, index, cur_id):
  """Measure ID-switch behaviour of one query over similarity thresholds.

  Distances are mapped to similarities with ``1 - dist / 2``. This inverts
  ``dist = 2 - 2 * sim``, the form used in ``compute_mAP_one_clip``. The
  thresholds are 0.95, 0.90, ..., 0.10. At each one, a gallery item is
  "matched" when its similarity is >= the threshold.

  Args:
    dist (np.ndarray): distance from the query to each gallery item.
    ids (np.ndarray): identity of each gallery item.
    index (np.ndarray): gallery order; expected to sort ``dist`` ascending
      (that is how ``compute_mAP_one_clip`` builds it).
    cur_id: identity of the query.

  Returns:
    first_wrong_sim: similarity of the best-ranked item with a different
      identity, or -1.0 when every gallery item shares the query identity.
    idsc (np.ndarray): 1 at thresholds below ``first_wrong_sim`` (a wrong
      identity would be matched), else 0.
    precision (np.ndarray): right-identity share of the matched items. It is
      1 when nothing is matched.
    recall (np.ndarray): matched share of the right-identity items. It is 0
      when there are none.
    f1 (np.ndarray): NOTE(review): this is the arithmetic mean of precision
      and recall, not the harmonic-mean F1 score. Confirm this is intended.
  """
  thresh = np.arange(0.95, 0.05, -0.05)
  sims = 1 - dist[index] / 2.
  ids = ids[index]
  right = (ids == cur_id)
  total_num = right.sum()
  idsc = np.zeros(len(thresh))
  # matched[t, j]: gallery item j clears threshold t.
  matched = thresh[:, None] <= sims[None, :]
  cnt_num = matched.sum(1)
  right_num = (matched & right[None, :]).sum(1)
  # Guard 0/0 when the gallery holds no item of the query's identity.
  recall = right_num / max(total_num, 1)
  # When nothing is matched, precision is defined as 1 (1/1).
  right_num[cnt_num == 0] = 1
  cnt_num[cnt_num == 0] = 1
  precision = right_num / cnt_num
  # sims is in rank order, so the first non-matching identity is the
  # top-ranked impostor. With no impostor at all, indexing would raise
  # IndexError. Fall back to -1.0, the lowest cosine similarity (assumes
  # L2-normalised features; TODO confirm), so no threshold counts a switch.
  wrong_sims = sims[~right]
  first_wrong_sim = wrong_sims[0] if wrong_sims.size else -1.0
  idsc[thresh < first_wrong_sim] = 1
  f1 = (precision + recall) / 2.
  return first_wrong_sim, idsc, precision, recall, f1

def compute_mAP_one_clip(ids, feats):
  """Evaluate re-ID retrieval within one clip, using every sample as a query.

  Each sample is queried against all the other samples of the clip. A query
  with no other sample of its identity is invalid. It is skipped and does not
  contribute to any metric.

  Args:
    ids (np.ndarray): identity (track id) of each sample.
    feats (np.ndarray): one embedding row per sample. The distance
      ``2 - 2 * feats @ feats.T`` equals the squared L2 distance only for
      L2-normalised rows (assumed, not checked here).

  Returns:
    ap: mean AP over valid queries.
    cmc (np.ndarray): first 10 entries of the mean CMC curve.
    valid_cnt (int): number of valid queries.
    wrong_sims (np.ndarray): first-wrong similarity of each valid query, of
      length ``valid_cnt``. Invalid queries are excluded rather than stored
      as 0, which would bias the mean downward.
    idsrs, precs, recs, f1s (np.ndarray): per-threshold means over valid
      queries of the arrays returned by ``compute_idswitch``.
  """
  badcase = {}
  dists = 2 - 2 * np.matmul(feats, feats.T)
  thresh = np.arange(0.95, 0.05, -0.05)
  idsrs = np.zeros(len(thresh))
  f1s = np.zeros(len(thresh))
  precs = np.zeros(len(thresh))
  recs = np.zeros(len(thresh))
  num = len(ids)
  cmc = np.zeros(max(num - 1, 10))
  wrong_sims = []
  ap = 0.
  iters = range(num)
  if num > 1000:
    iters = tqdm(iters)
  valid_cnt = 0

  for i in iters:
    # Gallery = every other sample of the clip (np.delete returns copies).
    cur_dist = np.delete(dists[i], i)
    cur_ids = np.delete(ids, i)
    index = cur_dist.argsort()
    good_index = np.where(cur_ids == ids[i])[0]
    ap_tmp, cmc_tmp = compute_mAP(index, good_index)
    if cmc_tmp[0] == -1:
      # No positive in the gallery: invalid query.
      continue
    if ap_tmp < 0.2:
      # Debug aid: poorly retrieved queries, keyed by query position.
      badcase[i] = good_index
    ap += ap_tmp
    cmc += cmc_tmp
    valid_cnt += 1
    wrong_sim, isr, prec, rec, f1 = compute_idswitch(cur_dist, cur_ids, index, ids[i])
    wrong_sims.append(wrong_sim)
    idsrs += isr
    f1s += f1
    precs += prec
    recs += rec
  if badcase:
    print(badcase)
  valid_cnt_div = valid_cnt if valid_cnt > 0 else 1
  cmc = cmc[:10] / valid_cnt_div
  ap = ap / valid_cnt_div
  idsrs /= valid_cnt_div
  f1s /= valid_cnt_div
  precs /= valid_cnt_div
  recs /= valid_cnt_div
  wrong_sims = np.asarray(wrong_sims, dtype=np.float64)
  return ap, cmc, valid_cnt, wrong_sims, idsrs, precs, recs, f1s

class WaymoEvaluator(MPIEvaluator):
  """ The evaluator for Waymo tasks. """

  def __init__(self, cfg, model, task, gather=True, pickup_first_batch=False,
               print_detail=True, cache_feat_dir=None):
    """
    Args:
      cfg (TaskConfig): the configuration of task.
      model (BaseModel): the model of task.
      task (str): the task of the waymo, currently only supports the
        ['camera_detection', 'camera_tracking']
      gather (bool): whether gather the results to device 0.
      pickup_first_batch (bool): forwarded unchanged to ``MPIEvaluator``.
      print_detail (bool): whether ``post_process`` prints per-clip scores.
      cache_feat_dir (str, optional): directory where model outputs are
        pickled per checkpoint. When a checkpoint's cache file already
        exists, ``evaluate`` post-processes it instead of running the model.
    """
    super().__init__(cfg, model, gather, pickup_first_batch=pickup_first_batch)
    self.task = task
    self.print_detail = print_detail
    self.cache_feat_dir = cache_feat_dir

  def _feat_cache_path(self, model_path):
    """Return the feature-cache pickle path for ``model_path``, or None.

    The file name is the checkpoint base name with ``ckpt-`` replaced by
    ``ep_``, plus ``_feat.pkl``. None is returned when caching is disabled or
    ``model_path`` is None.
    """
    if not self.cache_feat_dir or model_path is None:
      return None
    save_name = os.path.basename(model_path).replace('ckpt-', 'ep_') + '_feat.pkl'
    return os.path.join(self.cache_feat_dir, save_name)

  def format_results(self, results):
    """ format the results into pretty printable strings.

    Args:
      results: dict of values.

    Returns:
      str: one ``key = value`` line per entry. Python and numpy floats get 4
        decimals; arrays use ``np.array2string`` with precision 4.
    """
    lines = []
    for key, value in results.items():
      # TODO: add more dtypes here.
      # np.floating covers float32 scalars, which are not ``float`` subclasses.
      if isinstance(value, (float, np.floating)):
        lines.append(f'{key} = {value:.4f}')
      elif isinstance(value, np.ndarray):
        lines.append(f"{key} = {np.array2string(value, precision=4, separator=', ')}")
      else:
        lines.append(f'{key} = {value}')
    return '\n'.join(lines)

  def evaluate(self, model_path, restore=True):
    """Evaluate ``model_path``, reusing cached model outputs when present.

    If the feature cache for this checkpoint exists, the model is not run.
    The cached outputs are post-processed on the root process, or on every
    process when ``gather`` is False. The root process then appends a report
    to ``cfg.eval_file`` and logs it. Otherwise this defers to
    ``MPIEvaluator.evaluate``.
    """
    fpath = self._feat_cache_path(model_path)
    if fpath is None or not os.path.exists(fpath):
      return super().evaluate(model_path, restore=restore)

    # NOTE: pickle.load can execute code embedded in the file. Only point
    # cache_feat_dir at caches written by this evaluator.
    with open(fpath, 'rb') as f:
      outputs_list = pickle.load(f)
    results = None
    stop_watch = Stopwatch(True)
    if self.cfg.is_root or not self.gather:
      results = self.post_process(outputs_list, model_path)

    if self.cfg.is_root:
      stop_watch.acc()
      eval_cost = stop_watch.duration
      date = time.strftime('%Y-%m-%d  %H:%M:%S', time.gmtime())
      if isinstance(results, (tuple, list)):
        results, str_results = results
      elif isinstance(results, dict):
        str_results = self.format_results(results)
      else:
        str_results = str(results)
      msg = f'Eval {model_path}\nDate {date}\n{str_results}\n{eval_cost:.2f} s\n'
      with open(self.cfg.eval_file, 'a') as wf:
        wf.write(msg)
      self.logger.info(msg)
    return results

  def post_process(self, outputs_list, model_path):
    """ Compute the NIO-Evaluation results.

    Side effects:
    - pickles ``outputs_list`` to the feature cache when caching is enabled
      and no cache exists yet;
    - prints progress and scores;
    - writes a threshold plot (``.jpg``) and a text report (``.txt``) to
      ``cfg.eval_dir``.

    Args:
      outputs_list (list): the outputs from the model.
      model_path (str): the evaluated model_path, used to name the saved
        files. May be None.

    Returns:
      dict: ``mAP``, ``CMC1``, ``CMC5``, ``wsm`` and the per-threshold arrays
        ``thresh``, ``IDSwR``, ``prec``, ``rec``, ``f1``.
    """
    fpath = self._feat_cache_path(model_path)
    if fpath is not None and not os.path.exists(fpath):
      # exist_ok: several processes may reach this when gather is False.
      os.makedirs(self.cache_feat_dir, exist_ok=True)
      with open(fpath, 'wb') as f:
        pickle.dump(outputs_list, f)

    outputs = reshape_outputs(outputs_list)
    thresh = np.arange(0.95, 0.05, -0.05)
    total_ap = 0.
    total_cmc = np.zeros(10)
    total_num = 0
    WS = np.zeros(0)
    total_f1s = np.zeros(len(thresh))
    total_precs = np.zeros(len(thresh))
    total_recs = np.zeros(len(thresh))
    total_idsrs = np.zeros(len(thresh))
    # Report progress roughly every 10% of the clips.
    step = max((len(outputs) + 9) // 10, 1)
    bad_clips = []
    for cnt, (clip_id, val) in enumerate(outputs.items()):
      if cnt % step == 0:
        print(f'{cnt}|{len(outputs)} processing ...')
      ap, cmc, num, ws, idsrs, precs, recs, f1s = compute_mAP_one_clip(val['ids'], val['feats'])
      if self.print_detail:
        print_scores(ap, cmc, num, ws.mean() if len(ws) else 0.0, f'clip-{clip_id}')
      if ap < 0.3:
        bad_clips.append(clip_id)
      # Each clip is weighted by its number of valid queries.
      total_ap += ap * num
      total_cmc += cmc * num
      total_num += num
      WS = np.hstack([WS, ws])
      total_idsrs += idsrs * num
      total_f1s += f1s * num
      total_precs += precs * num
      total_recs += recs * num
    # Print ids only: the per-clip dicts hold entire feature matrices.
    print(f'clips with mAP < 0.3: {bad_clips}')

    # Same guard as compute_mAP_one_clip: no valid query at all (e.g. empty
    # outputs_list) would otherwise raise ZeroDivisionError.
    denom = total_num if total_num > 0 else 1
    mAP = total_ap / denom
    CMC = total_cmc / denom
    wsm = WS.mean() if WS.size else 0.0
    total_idsrs /= denom
    total_f1s /= denom
    total_precs /= denom
    total_recs /= denom
    print_scores(mAP, CMC, total_num, wsm, 'total: ')
    results = {'mAP': mAP, 'CMC1': CMC[0], 'CMC5': CMC[4],
               'wsm': wsm, 'thresh': thresh, 'IDSwR': total_idsrs,
               'prec': total_precs, 'rec': total_recs, 'f1': total_f1s}
    eval_dataset, eval_mode = self.data_provider.eval_dataset, self.data_provider.eval_mode
    prefix = f"{eval_dataset}_{eval_mode}_" if eval_dataset else ""
    # model_path may be None (the caching code above allows for it), and
    # os.path.basename(None) would raise.
    if model_path is not None:
      ckpt_name = os.path.basename(model_path).replace('ckpt-', 'ep')
    else:
      ckpt_name = 'unknown'
    fig_save_path = os.path.join(self.cfg.eval_dir, f'{prefix}vis_{ckpt_name}.jpg')
    plot_figure(thresh, wsm, total_idsrs, total_precs, total_recs, fig_save_path)
    str_results = self.format_results(results)
    res_save_path = os.path.join(self.cfg.eval_dir, f'{prefix}res_{ckpt_name}.txt')
    with open(res_save_path, 'w') as wf:
      wf.write(f'Eval {model_path}\n{str_results}\n')
    return results

  def visualize(self, data_dict, vis_dir):
    """Render ``data_dict`` with the detection visualizer into ``vis_dir``,
    labelled with ``WAYMO_CLASSES``."""
    DetectionVisualizerWrapper()(data_dict, WAYMO_CLASSES, vis_dir)


if __name__ == "__main__":
  # Offline re-evaluation of a pickled feature cache. The default file follows
  # the ``ep_<N>_feat.pkl`` naming used by the evaluator's cache.
  import argparse

  parser = argparse.ArgumentParser(
      description='Recompute re-ID metrics from a cached feature pickle.')
  parser.add_argument(
      'feat_pkl', nargs='?',
      default='/home/zhixing.cheng/reid/object-reid/reid/zhixing.cheng/reid_bs_PCBv3_reid/reid_20test_caches/ep_24_feat.pkl',
      help='pickled list of model outputs')
  parser.add_argument('--fig', default=None,
                      help='optional output path for the threshold plot')
  args = parser.parse_args()

  # NOTE: pickle.load can execute code embedded in the file; only load
  # caches you produced yourself.
  with open(args.feat_pkl, 'rb') as f:
    data = pickle.load(f)
  outputs = reshape_outputs(data)
  thresh = np.arange(0.95, 0.05, -0.05)
  total_ap = 0.
  total_cmc = np.zeros(10)
  total_num = 0
  WS = np.zeros(0)
  total_f1s = np.zeros(len(thresh))
  total_precs = np.zeros(len(thresh))
  total_recs = np.zeros(len(thresh))
  total_idsrs = np.zeros(len(thresh))
  # Report progress roughly every 10% of the clips.
  step = max((len(outputs) + 9) // 10, 1)
  for cnt, val in enumerate(outputs.values()):
    if cnt % step == 0:
      print(f'{cnt}|{len(outputs)} processing ...')
    ap, cmc, num, ws, idsrs, precs, recs, f1s = compute_mAP_one_clip(val['ids'], val['feats'])
    # Each clip is weighted by its number of valid queries.
    total_ap += ap * num
    total_cmc += cmc * num
    total_num += num
    WS = np.hstack([WS, ws])
    total_idsrs += idsrs * num
    total_f1s += f1s * num
    total_precs += precs * num
    total_recs += recs * num

  # Avoid ZeroDivisionError when no clip had a valid query.
  denom = total_num if total_num > 0 else 1
  mAP = total_ap / denom
  CMC = total_cmc / denom
  wsm = WS.mean() if WS.size else 0.0
  total_idsrs /= denom
  total_f1s /= denom
  total_precs /= denom
  total_recs /= denom
  print_scores(mAP, CMC, total_num, wsm, 'total: ')
  for t, isr, p, r, f in zip(thresh, total_idsrs, total_precs, total_recs, total_f1s):
    print(f'thresh {t:.2f}: IDSwR {isr:.4f}  prec {p:.4f}  rec {r:.4f}  f1 {f:.4f}')
  if args.fig:
    plot_figure(thresh, wsm, total_idsrs, total_precs, total_recs, args.fig)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值