# Adapted from the official validation script.
import json
import argparse
import time
import shutil
import pandas as pd
def __load_data(submit_file, reference_file, submit_dict, ref_dict, result=None):
    """Load submission and reference JSON files into image_id-keyed dicts.

    Each file is read as a JSON list of objects that carry 'image_id' and
    'label_id' keys. Submitted labels are stored unchanged; reference labels
    are converted with int().

    Args:
        submit_file: path to the submission JSON file.
        reference_file: path to the reference JSON file.
        submit_dict: dict filled in place with image_id -> submitted label_id.
        ref_dict: dict filled in place with image_id -> int(reference label_id).
        result: optional report dict with a 'warning' list. When omitted, the
            module-level ``result`` dict is used (the original behaviour when
            run as a script); if no such dict exists, the warning is printed.

    Returns:
        The (submit_dict, ref_dict) pair -- the same objects that were passed in.

    Raises:
        OSError if a file cannot be opened, json.JSONDecodeError on invalid
        JSON, KeyError on entries missing 'image_id'/'label_id', and
        ValueError/TypeError if a reference label is not int()-convertible.
    """
    if result is None:
        # Backward compatibility: the script's __main__ block defines a
        # module-level `result` dict that this function used implicitly.
        result = globals().get('result')

    with open(submit_file, 'r', encoding='utf-8') as fh:
        submit_data = json.load(fh)
    with open(reference_file, 'r', encoding='utf-8') as fh:
        ref_data = json.load(fh)

    if len(submit_data) != len(ref_data):
        msg = 'Inconsistent number of images between submission and reference data \n'
        if result is not None:
            result['warning'].append(msg)
        else:
            # No report dict available (e.g. imported as a module): do not
            # crash with NameError as the original did; surface it instead.
            print(msg.strip())

    for item in submit_data:
        submit_dict[item['image_id']] = item['label_id']
    for item in ref_data:
        ref_dict[item['image_id']] = int(item['label_id'])
    return submit_dict, ref_dict
def __eval_result(submit_dict, ref_dict, result=None, top_k=3, max_correct=100):
    """Compute top-k accuracy of a submission against the reference labels.

    An image counts as correct when its reference label appears among the
    first ``top_k`` submitted labels. Reference images missing from the
    submission are reported as warnings and count as incorrect in the score,
    but are not listed in ``wrong_ids``.

    Args:
        submit_dict: image_id -> sequence of predicted labels.
        ref_dict: image_id -> reference label.
        result: optional report dict with 'warning' and 'score' keys. When
            omitted, the module-level ``result`` dict is used (the original
            behaviour when run as a script); if none exists, a fresh report
            dict is created.
        top_k: number of leading predictions checked (default 3).
        max_correct: maximum number of correct image ids collected (default 100).

    Returns:
        (result, wrong_ids, correct_ids): ``result['score']`` is set to the
        accuracy as a string; ``wrong_ids`` lists every submitted image whose
        reference label is not in its top-k predictions; ``correct_ids`` lists
        the first ``max_correct`` correctly classified images.
    """
    if result is None:
        # Backward compatibility with the script's implicit global report dict.
        result = globals().get('result')
    if result is None:
        result = {'error': [], 'warning': [], 'score': None}

    wrong_ids = []
    correct_ids = []
    right_count = 0
    for key, value in ref_dict.items():
        # Direct dict membership test: O(1), instead of rebuilding a set of
        # all submission keys for every reference image.
        if key not in submit_dict:
            result['warning'].append('lacking image %s in your submission file \n' % key)
            print('warnning: lacking image %s in your submission file' % key)
            continue
        if value in submit_dict[key][:top_k]:
            right_count += 1
            # Only keep a bounded sample of correct images.
            if right_count <= max_correct:
                correct_ids.append(key)
        else:
            wrong_ids.append(key)
    # max(..., 1e-5) avoids division by zero for an empty reference set.
    result['score'] = str(float(right_count) / max(len(ref_dict), 1e-5))
    return result, wrong_ids, correct_ids
if __name__ == '__main__':
    # Script entry point: score a submission against reference labels, then copy
    # misclassified images and a sample of correctly classified images into two
    # output folders for manual inspection.
    import os  # local import: not part of the file-level import block

    # NOTE(review): scene_classes is loaded but not used below; kept so the
    # script's requirements (pandas + scene_classes.csv) are unchanged.
    scene_classes = pd.read_csv('scene_classes.csv')
    wrongs = 'validation_wrong'      # output folder for misclassified images
    corrects = 'validation_correct'  # output folder for correctly classified images
    # Validation image directory. os.path.join replaces the Windows-only '\\'
    # separator; on Windows the resulting path is identical.
    path = os.path.join('ai_challenger_scene_validation_20170908',
                        'scene_validation_images_20170908')
    submit_dict = {}
    ref_dict = {}

    PARSER = argparse.ArgumentParser()
    PARSER.add_argument('--submit', type=str, default='./submit.json',
                        help='Path to submission file')
    PARSER.add_argument('--ref', type=str, default='./ref.json',
                        help='Path to reference file')
    FLAGS = PARSER.parse_args()

    # Module-level report dict; __load_data / __eval_result append to it.
    result = {'error': [], 'warning': [], 'score': None}
    START_TIME = time.time()
    SUBMIT = {}
    REF = {}
    # Pre-initialised so that a failed evaluation skips the copy step below
    # instead of crashing with NameError.
    wrong_ids = []
    correct_ids = []
    try:
        SUBMIT, REF = __load_data(FLAGS.submit, FLAGS.ref, submit_dict, ref_dict)
    except Exception as error:
        result['error'].append(str(error))
    try:
        result, wrong_ids, correct_ids = __eval_result(SUBMIT, REF)
    except Exception as error:
        result['error'].append(str(error))
    print('Evaluation time of your result: %f s' % (time.time() - START_TIME))
    print(result)

    # Copy each selected image as <stem><submitted labels><reference label>.jpg
    # (same naming as before) into its output folder, creating it if needed.
    for ids, dest_dir in ((wrong_ids, wrongs), (correct_ids, corrects)):
        if ids:
            os.makedirs(dest_dir, exist_ok=True)
        for item in ids:
            new_name = (item.split('.')[0] + str(submit_dict[item])
                        + str(ref_dict[item]) + '.jpg')
            shutil.copyfile(os.path.join(path, item),
                            os.path.join(dest_dir, new_name))