HPatches dataset (image matching), part 2: explaining the evaluation code and plotting the results with Python

On plotting:

Reference:

Matplotlib series:

https://blog.csdn.net/yuyh131/category_7823048.html

 

Explaining the evaluation code:

First, download the descriptors that each algorithm has already extracted from the dataset patches:

 ./download.sh descr
List of available descriptor results file for HPatches:
+-------------------+---------------------------------------------------------------------+
|       name        |                             description                             |
+-------------------+---------------------------------------------------------------------+
| ncc               | Normalised cross correlation                                        |
| sift              | SIFT [Lowe IJCV 2004]                                               |
| rootsift          | rootSIFT [Arandjelović & Zisserman CVPR 2012]                       |
| orb               | ORB [Rublee et al ICCV 2011]                                        |
| brief             | BRIEF [Calonder et al. PAMI 2012]                                   |
| binboost          | BinBoost [Trzcinski et al. PAMI 2013]                               |
| deepdesc          | DeepDesc [Simo-Serra et al. ICCV 2015]                              |
| liop              | LIOP [Wang et al ICCV 2011]                                         |
| tfeat-margin-star | TFeat with margin loss [Balntas et al. BMVC 2016]                   |
| tfeat-ratio-star  | TFeat with ratio loss [Balntas et al. BMVC 2016]                    |
| dc-siam           | DeepCompare siamese [Zagoruyko & Komodakis CVPR 2015]               |
| dc-siam2stream    | DeepCompare siamese 2-stream [Zagoruyko & Komodakis CVPR 2015]      |
| hardnet           | HardNet [Mishchuk et al. NIPS 2017]                                 |
| hardnet+          | HardNet with training data augmentation [Mishchuk et al. NIPS 2017] |
+-------------------+---------------------------------------------------------------------+

For example, ORB:
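The downloaded files follow the same on-disk layout that the extraction script in section 2 below writes: one folder per descriptor, one sub-folder per sequence, and one CSV per patch type. An illustrative sketch (the sequence names are just examples):

data/descriptors/orb/
    i_ajuntament/
        ref.csv   # one row per patch, one column per descriptor dimension
        e1.csv
        ...
        t5.csv
    v_calder/
        ...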

 

----------------

1. Visualising the dataset:

python hpatches_vis.py

from utils.hpatch import *
import cv2
import os.path


# all types of patches
tps = ['ref','e1','e3','e5','h1','h3','h5','t1','t3','t5']
datadir = './ImageMatch_dataset/data'
#datadir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "data"))

def vis_patches(seq,tp,ids):
    """Visualises a set of types and indices for a sequence"""
    h = len(tp)*65
    vis = np.empty((h, 0))
    # add the first column with the patch type names
    vis_tmp = np.empty((0,55))
    for t in tp:
        tp_patch = 255*np.ones((65,55))
        cv2.putText(tp_patch,t,(5,25),cv2.FONT_HERSHEY_DUPLEX , 1,0,1)
        vis_tmp = np.vstack((vis_tmp,tp_patch))
    vis = np.hstack((vis,vis_tmp))
    # add the actual patches
    for idx in ids:
        vis_tmp = np.empty((0,65))
        for t in tp:
            vis_tmp = np.vstack((vis_tmp,get_patch(seq,t,idx)))
        vis = np.hstack((vis,vis_tmp))
    return vis


# select a subset of types of patches to visualise
# tp = ['ref','e5','h5','t5']
#or visualise all - tps holds all possible types
tp = tps

# list of patch indices to visualise
ids = range(1,55)

# load a sample sequence
seq = hpatch_sequence(os.path.join(datadir, "hpatches-release", "v_calder"))
vis = vis_patches(seq,tp,ids)

# show (vis is float, so divide by 255 to get the [0,1] range imshow expects)
cv2.imshow("HPatches example", vis/255)
cv2.waitKey(0)

# or save
cv2.imwrite("patches.png", vis)

This uses the following helpers from utils/hpatch.py:

class hpatch_sequence:
    """Class for loading an HPatches sequence from a sequence folder"""
    itr = tps
    def __init__(self,base):
        name = base.split(os.path.sep)
        self.name = name[-1]
        self.base = base
        for t in self.itr:
            im_path = os.path.join(base, t+'.png')
            im = cv2.imread(im_path,0)
            # each .png stacks all the patches of one type vertically, 65 px per patch
            self.N = im.shape[0]//65
            setattr(self, t, np.split(im, self.N))
            '''
            np.split cuts the stacked image into a list of 65x65 patches, e.g.:

            >>> A = np.arange(12).reshape(3, 4)
            >>> np.split(A, 3, axis=0)
            [array([[0, 1, 2, 3]]), array([[4, 5, 6, 7]]), array([[ 8,  9, 10, 11]])]
            '''


def get_patch(seq,t,idx):
    """Gets a patch from a sequence with type=t and id=idx"""
    return getattr(seq, t)[idx]
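A quick usage sketch for these helpers (a minimal example, reusing the datadir from the script above; v_calder is one of the dataset's sequence folders):

from utils.hpatch import *
import os.path

datadir = './ImageMatch_dataset/data'
seq = hpatch_sequence(os.path.join(datadir, "hpatches-release", "v_calder"))
patch = get_patch(seq, 'e1', 0)   # patch 0 of the 'e1' (easy, level 1) image
print(patch.shape)                # (65, 65)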

2. Extracting descriptors for all patches with a known algorithm, e.g. SIFT

import os
import sys
import time
import glob
import cv2
import numpy as np

#assert len(sys.argv)==3, "Usage python extract_opencv_sift.py hpatches_db_root_folder 65"
OUT_W = 65
#OUT_W = int(sys.argv[2])
# all types of patches 
tps = ['ref','e1','e2','e3','e4','e5','h1','h2','h3','h4','h5',\
       't1','t2','t3','t4','t5']

# class for loading an HPatches sequence
class hpatches_sequence:
    """Class for loading an HPatches sequence from a sequence folder"""
    itr = tps
    def __init__(self,base):
        name = base.split('/')
        self.name = name[-1]
        self.base = base
        for t in self.itr:
            im_path = os.path.join(base, t+'.png')
            im = cv2.imread(im_path,0)
            self.N = im.shape[0]//65  # number of 65x65 patches stacked vertically
            setattr(self, t, np.split(im, self.N))


#seqs = glob.glob(sys.argv[1]+'/*')
seqs = glob.glob("./ImageMatch_dataset/data/hpatches-release"+'/*')
seqs = [os.path.abspath(p) for p in seqs]

descr_name = 'opencv-sift-'+str(OUT_W)
# needs opencv-contrib; in OpenCV >= 4.4 SIFT lives in the main module as cv2.SIFT_create()
sift1 = cv2.xfeatures2d.SIFT_create()

# create a single keypoint at the centre of the (65x65) patch
def get_center_kp(PS=65.):
    c = PS/2.0
    center_kp = cv2.KeyPoint()
    center_kp.pt = (c,c)
    # size: diameter of the keypoint's neighbourhood
    center_kp.size = 2*c/5.303
    return center_kp

ckp = get_center_kp(OUT_W)

# iterate over every sequence in the dataset
for seq_path in seqs:
    seq = hpatches_sequence(seq_path)
    path = os.path.join(descr_name,seq.name)
    if not os.path.exists(path):
        os.makedirs(path)
    descr = np.zeros((int(seq.N),int(128))) # placeholder for the 128-D descriptors
    # iterate over the patch sets for each geometric-noise type
    for tp in tps:
        print(seq.name+'/'+tp)
        if os.path.isfile(os.path.join(path,tp+'.csv')):
            continue
        n_patches = len(getattr(seq, tp))
        t = time.time()
        patches_resized = np.zeros((n_patches, 1, OUT_W, OUT_W)).astype(np.uint8)
        # resize only if the requested width differs from the native 65x65
        if OUT_W != 65:
            for i,patch in enumerate(getattr(seq, tp)):
                patches_resized[i,0,:,:] = cv2.resize(patch,(OUT_W,OUT_W))
        else:
            for i,patch in enumerate(getattr(seq, tp)):
                patches_resized[i,0,:,:] = patch
        outs = []
        bs = 1  # batch size: one patch at a time
        n_batches = n_patches // bs
        # iterate over the patches batch by batch
        for batch_idx in range(int(n_batches)):
            if batch_idx == n_batches - 1:
                if (batch_idx + 1) * bs > n_patches:
                    end = n_patches
                else:
                    end = (batch_idx + 1) * bs
            else:
                end = (batch_idx + 1) * bs
            data_a = patches_resized[batch_idx * bs: end, :, :, :]
            # compute the descriptor at the fixed centre keypoint defined above
            outs.append(sift1.compute(data_a[0,0],[ckp])[1][0].reshape(-1, 128))
        res_desc = np.concatenate(outs)
        out = np.reshape(res_desc, (n_patches, -1))
        np.savetxt(os.path.join(path,tp+'.csv'), out, delimiter=';', fmt='%d')
        #np.savetxt(os.path.join(path,tp+'.csv'), out, delimiter=',', fmt='%10.5f')
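A quick sanity check of the exported files (a minimal sketch; the path below is one the script above would write): each CSV should contain one 128-D descriptor row per patch.

import numpy as np

d = np.loadtxt("opencv-sift-65/v_calder/ref.csv", delimiter=';')
print(d.shape)  # (number of patches in the sequence, 128)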

3. Evaluating descriptors: learned (ML/DL) or hand-crafted

1) Saving results

Evaluation code for the HPatches homography patches dataset.

Usage:
  hpatches_eval.py (-h | --help)
  hpatches_eval.py --version
  hpatches_eval.py --descr-name=<> --task=<>... [--descr-dir=<>] [--results-dir=<>] [--split=<>] [--dist=<>] [--delimiter=<>] [--pcapl=<>]

Options:
  -h --help         Show this screen.
  --version         Show version.
  --descr-name=<>   Descriptor name, e.g. sift
  --descr-dir=<>    Descriptor results root folder. [default: {root}/data/descriptors]
  --results-dir=<>  Results root folder. [default: results]
  --task=<>         Task name. Valid tasks are {verification,matching,retrieval}.
  --split=<>        Split name. Valid are {a,b,c,full,illum,view}. [default: a]
  --dist=<>         Distance name. Valid are {L1,L2}. [default: L2]
  --delimiter=<>    Delimiter used in the csv files. [default: ,]
  --pcapl=<>        Compute results for pca-power law descr. [default: no]

For more visit: https://github.com/hpatches/

Note that --delimiter=";" must match the separator used in the descriptor CSV files (the extraction script above writes semicolons). The unquoted and quoted forms below are equivalent:

python hpatches_eval.py --descr-name=sift  --task=verification  --descr-dir=./ImageMatch_dataset/data/descriptors  --results-dir=./results  --split=a  --delimiter=";"  --dist=L2  --pcapl=no

python hpatches_eval.py --descr-name="sift"  --task="verification"  --descr-dir="./ImageMatch_dataset/data/descriptors"  --results-dir="./results"  --split="a"  --delimiter=";"  --dist="L2"  --pcapl="no"

>> Running HPatch evaluation for sift
>> Please wait, loading the descriptor files...
>> Descriptor files loaded.
>> Evaluating verification task
./ImageMatch_dataset/code/hpatches-benchmark-master/python/utils/tasks.py:99: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.
  pos = pd.read_csv(os.path.join(tskdir, 'verif_pos_split-'+split['name']+'.csv')).as_matrix()
./ImageMatch_dataset/code/hpatches-benchmark-master/python/utils/tasks.py:100: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.
  neg_intra = pd.read_csv(os.path.join(tskdir, 'verif_neg_intra_split-'+split['name']+'.csv')).as_matrix()
./ImageMatch_dataset/code/hpatches-benchmark-master/python/utils/tasks.py:101: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.
  neg_inter = pd.read_csv(os.path.join(tskdir, 'verif_neg_inter_split-'+split['name']+'.csv')).as_matrix()
Processing verification task 1/3 : 100%|████████████████████████████████████████████████████████████████████████████████| 1000000/1000000 [00:59<00:00, 16697.15it/s]
Processing verification task 2/3 : 100%|████████████████████████████████████████████████████████████████████████████████| 1000000/1000000 [01:16<00:00, 13040.02it/s]
Processing verification task 3/3 : 100%|████████████████████████████████████████████████████████████████████████████████| 1000000/1000000 [01:16<00:00, 13028.28it/s]
>> Verification task finished in 225 secs

2) Printing results

Usage:
  hpatches_results.py (-h | --help)
  hpatches_results.py --version
  hpatches_results.py --descr-name=<>... --task=<>... [--results-dir=<>] [--split=<>] [--pcapl=<>]

Options:
  -h --help         Show this screen.
  --version         Show version.
  --descr-name=<>   Descriptor name e.g. --descr=sift.
  --results-dir=<>  Results root folder. [default: results]
  --task=<>         Task name. Valid tasks are {verification,matching,retrieval}.
  --split=<>        Split name. Valid are {a,b,c,full,illum,view}. [default: a]
  --pcapl=<>        Show results for pca-power law descr. [default: no]

For more visit: https://github.com/hpatches/

python hpatches_results.py --descr-name=sift --results-dir=./results --task=verification  --split=a  --pcapl=no

python hpatches_results.py --descr-name="sift"   --results-dir="./results"  --task="verification"  --split="a"  --pcapl="no"

Verification task results:
SIFT - Balanced variant (auc) 
Noise       Inter     Intra
-------  --------  --------
Easy     0.930915  0.909937
Hard     0.823407  0.795857
Tough    0.730735  0.705034
SIFT - Imbalanced variant (ap) 
Noise       Inter     Intra
-------  --------  --------
Easy     0.84945   0.783138
Hard     0.656835  0.569733
Tough    0.512454  0.42948

 

Batch evaluation script: the same code with a for loop over every descriptor folder,

modified from hpatches_eval.py:

# -*- coding: utf-8 -*-
"""Evaluation code for the HPatches homography patches dataset.

Usage:
  hpatches_eval.py (-h | --help)
  hpatches_eval.py --version
  hpatches_eval.py --descr-name=<> --task=<>... [--descr-dir=<>] [--results-dir=<>] [--split=<>] [--dist=<>] [--delimiter=<>] [--pcapl=<>]

Options:
  -h --help         Show this screen.
  --version         Show version.
  --descr-name=<>   Descriptor name, e.g. sift
  --descr-dir=<>    Descriptor results root folder. [default: {root}/data/descriptors]
  --results-dir=<>  Results root folder. [default: results]
  --task=<>         Task name. Valid tasks are {verification,matching,retrieval}.
  --split=<>        Split name. Valid are {a,b,c,full,illum,view}. [default: a]
  --dist=<>         Distance name. Valid are {L1,L2}. [default: L2]
  --delimiter=<>    Delimiter used in the csv files. [default: ,]
  --pcapl=<>        Compute results for pca-power law descr. [default: no]

For more visit: https://github.com/hpatches/
"""
from utils.hpatch import *
from utils.tasks import *
from utils.misc import *
from utils.docopt import docopt
import os
import time
import dill

#three tasks: verification, matching, retrieval
#example: python hpatches_eval.py --descr-name=sift --task=verification --delimiter=";"
if __name__ == '__main__':
    opts = docopt(__doc__, version='HPatches 1.0')
    descr_dir = opts['--descr-dir'].format(
        root=os.path.normpath(os.path.join(os.path.abspath(os.path.dirname(__file__)), ".."))
    )
    descr_dir = "/home/boyun/deepglint/ImageMatch_dataset/data/descriptors"
    path = os.path.join(descr_dir, opts['--descr-name'])

    # evaluate every descriptor folder found under descr_dir
    lines = os.listdir(descr_dir)
    for line in lines:

        opts['--descr-name'] = line
        # the sift/rootsift CSVs here are semicolon-separated; the rest use commas
        if line == "sift" or line == "rootsift":
            opts['--delimiter'] = ";"
        else:
            opts['--delimiter'] = ","

        path = os.path.join(descr_dir, line)

        if not os.path.exists(path):
            print("%r does not exist." % (path))
            exit(0)

        results_dir = opts['--results-dir']
        if not os.path.exists(results_dir):
            os.makedirs(results_dir)

        descr_name = opts['--descr-name']
        print('\n>> Running HPatch evaluation for %s' % blue(descr_name))

        descr = load_descrs(path,dist=opts['--dist'],sep=opts['--delimiter'])

        with open(os.path.join(tskdir, "splits", "splits.json")) as f:
            splits = json.load(f)

        splt = splits[opts['--split']]

        for t in opts['--task']:
            res_path = os.path.join(results_dir, descr_name+"_"+t+"_"+splt['name']+".p")
            if os.path.exists(res_path):
                print("Results for the %s, %s task, split %s, already cached!" %\
                      (descr_name,t,splt['name']))
            else:
                res = methods[t](descr,splt)
                dill.dump(res, open(res_path, "wb"))

        # do the PCA/power-law evaluation if wanted
        if opts['--pcapl']!='no':
            print('>> Running evaluation for %s normalisation' % blue("pca/power-law"))
            compute_pcapl(descr,splt)
            for t in opts['--task']:
                res_path = os.path.join(results_dir, descr_name+"_pcapl_"+t+"_"+splt['name']+".p")
                if os.path.exists(res_path):
                    print("Results for the %s, %s task, split %s,PCA/PL already cached!" %\
                          (descr_name,t,splt['name']))
                else:
                    res = methods[t](descr,splt)
                    dill.dump(res, open(res_path, "wb"))

The evaluation code in detail:

For training and evaluation, HPatches is divided into train/test splits (some splits are test-only).

The authors provide several split schemes.

Following the original splits.json, the scene sequences contained in each split are listed below:
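The split definitions can also be inspected programmatically, e.g. with this minimal sketch (tskdir is the benchmark's tasks folder, as used by hpatches_eval.py above):

import json
import os
from utils.tasks import tskdir

with open(os.path.join(tskdir, "splits", "splits.json")) as f:
    splits = json.load(f)

for name, split in splits.items():
    print(name, "test:", len(split["test"]), "train:", len(split.get("train", [])))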

Split scheme: a

"test": 
"i_ajuntament", "i_resort", "i_table", "i_troulos", "i_bologna", "i_lionnight", "i_porta", "i_zion", "i_brooklyn",
 "i_fruits", "i_books", "i_bridger", "i_whitebuilding", "i_kurhaus", "i_salon", "i_autannes", "i_tools", 
"i_santuario", "i_fog", "i_nijmegen", "v_courses", "v_coffeehouse", "v_abstract", "v_feast", 
"v_woman", "v_talent", "v_tabletop", "v_bees", "v_strand", "v_fest", "v_yard", "v_underground",
 "v_azzola", "v_eastsouth", "v_yuri", "v_soldiers", "v_man", "v_pomegranate", "v_birdwoman", "v_busstop"

"train": 
"v_there", "i_yellowtent", "i_boutique", "v_wapping", "i_leuven", "i_school", "i_crownnight", "v_artisans", 
"v_colors", "i_ski", "v_circus", "v_tempera", "v_london", "v_war", "i_parking", "v_bark", "v_charing",
 "i_indiana", "v_weapons", "v_wormhole", "v_maskedman", "v_dirtywall", "v_wall", "v_vitro", "i_nuts",
 "i_londonbridge", "i_pool", "i_pinard", "i_greentea", "v_calder", "i_lionday", "i_crownday", "i_kions",
 "v_posters", "i_dome", "v_machines", "v_laptop", "v_boat", "v_churchill", "i_pencils", "v_beyus", 
"v_sunseason", "v_samples", "v_cartooncity", "v_gardens", "v_bip", "v_home", "i_veggies", "i_nescafe",
 "v_wounded", "i_toy", "v_dogman", "i_duda", "i_contruction", "v_graffiti", "i_gonnenberg", "v_astronautis",
 "i_ktirio", "i_castle", "i_greenhouse", "i_fenis", "i_partyfood", "v_adam", "v_apprentices", "v_blueprint",
 "i_smurf", "i_objects", "v_bird", "i_melon", "v_grace", "i_miniature", "v_bricks", "i_chestnuts", 
"i_village", "i_steps", "i_dc"

Split scheme: b

"test": 
"i_fruits", "i_melon", "i_castle", "i_resort", "i_chestnuts", "i_kions", "i_kurhaus", "i_autannes", 
"i_duda", "i_partyfood", "i_ski", "i_dome", "i_greenhouse", "i_pencils", "i_porta", "i_lionday", 
"i_school", "i_bridger", "i_village", "i_fog", "v_astronautis", "v_bip", "v_charing", "v_woman", 
"v_feast", "v_yard", "v_churchill", "v_graffiti", "v_london", "v_sunseason", "v_posters", "v_bees", 
"v_apprentices", "v_birdwoman", "v_colors", "v_laptop", "v_there", "v_adam", "v_underground", "v_war"

"train": 
"v_wormhole", "i_yellowtent", "i_boutique", "v_wapping", "i_leuven", "i_pinard", "i_crownnight", 
"v_artisans", "i_toy", "v_circus", "v_tempera", "i_lionnight", "i_parking", "v_soldiers", 
"i_contruction", "i_whitebuilding", "i_indiana", "v_azzola", "v_weapons", "v_fest", "v_yuri", 
"v_dirtywall", "v_eastsouth", "v_man", "v_wall", "v_vitro", "i_nuts", "i_londonbridge", "i_pool", 
"v_tabletop", "i_greentea", "i_brooklyn", "i_smurf", "v_cartooncity", "v_wounded", "v_calder", 
"v_coffeehouse", "v_grace", "v_machines", "i_tools", "v_boat", "v_beyus", "v_strand", "i_santuario", 
"i_crownday", "v_bark", "i_veggies", "i_nescafe", "v_maskedman", "v_abstract", "v_talent", "i_books", 
"i_table", "v_courses", "i_nijmegen", "i_salon", "i_gonnenberg", "v_samples", "i_ktirio", "v_gardens", 
"i_zion", "v_pomegranate", "i_fenis", "v_home", "i_ajuntament", "i_objects", "v_bird", "v_dogman", 
"i_troulos", "i_miniature", "v_bricks", "i_bologna", "v_busstop", "v_blueprint", "i_steps", "i_dc"

Split scheme: c

"test": 
"i_ski", "i_table", "i_troulos", "i_melon", "i_tools", "i_kions", "i_londonbridge", "i_nijmegen", 
"i_boutique", "i_parking", "i_steps", "i_fog", "i_leuven", "i_dc", "i_partyfood", "i_pool", "i_castle", 
"i_bologna", "i_smurf", "i_crownnight", "v_azzola", "v_tempera", "v_machines", "v_coffeehouse", 
"v_graffiti", "v_artisans", "v_maskedman", "v_talent", "v_bees", "v_dirtywall", "v_blueprint", 
"v_war", "v_adam", "v_pomegranate", "v_busstop", "v_weapons", "v_gardens", "v_feast", "v_man", 
"v_wounded"
		
"train":
"v_there", "i_yellowtent", "i_whitebuilding", "v_wapping", "v_laptop", "i_school", "v_calder", "i_duda",
 "v_circus", "i_porta", "v_home", "i_lionnight", "i_chestnuts", "v_abstract", "v_soldiers", "i_contruction",
 "v_charing", "i_indiana", "v_strand", "v_fest", "v_yuri", "v_wormhole", "v_eastsouth", "i_autannes", 
"v_colors", "v_wall", "v_vitro", "i_nuts", "i_pinard", "v_tabletop", "i_brooklyn", "i_lionday", 
"i_crownday", "v_bip", "v_posters", "v_underground", "i_dome", "v_grace", "i_ajuntament", "v_cartooncity",
 "v_boat", "v_churchill", "i_pencils", "v_beyus", "v_sunseason", "v_samples", "i_kurhaus", "i_santuario",
 "i_resort", "i_zion", "i_veggies", "i_nescafe", "i_toy", "v_dogman", "i_books", "v_courses", "v_birdwoman", 
"v_yard", "i_salon", "i_gonnenberg", "v_astronautis", "i_ktirio", "i_bridger", "i_greenhouse", "i_fenis", 
"v_woman", "v_bricks", "v_apprentices", "i_greentea", "i_objects", "v_bird", "v_london", "i_fruits", 
"i_miniature", "i_village", "v_bark"

Split scheme: illum

"test":
"i_crownnight", "i_table", "i_objects", "i_nescafe", "i_nijmegen", "i_whitebuilding", "i_porta", 
"i_santuario", "i_dc", "i_castle", "i_steps", "i_contruction", "i_melon", "i_yellowtent", 
"i_miniature", "i_troulos", "i_veggies", "i_zion", "i_gonnenberg", "i_autannes", "i_boutique",
 "i_fruits", "i_pool", "i_fog", "i_fenis", "i_village", "i_ajuntament", "i_partyfood", "i_kurhaus", 
"i_school", "i_chestnuts", "i_smurf", "i_indiana", "i_pinard", "i_lionnight", "i_kions", "i_ski", 
"i_greenhouse", "i_ktirio", "i_tools", "i_toy", "i_bridger", "i_lionday", "i_brooklyn", "i_crownday",
 "i_londonbridge", "i_greentea", "i_leuven", "i_nuts", "i_resort", "i_bologna", "i_duda", "i_dome",
 "i_pencils", "i_books", "i_parking", "i_salon"

Split scheme: view

"test": 
"v_circus", "v_charing", "v_colors", "v_astronautis", "v_maskedman", "v_talent", "v_london", 
"v_underground", "v_coffeehouse", "v_calder", "v_grace", "v_yard", "v_dogman", "v_laptop", 
"v_eastsouth", "v_boat", "v_strand", "v_busstop", "v_artisans", "v_machines", "v_soldiers", 
"v_home", "v_wapping", "v_wounded", "v_weapons", "v_adam", "v_there", "v_vitro", "v_cartooncity", 
"v_abstract", "v_dirtywall", "v_beyus", "v_apprentices", "v_sunseason", "v_wall", "v_war",
 "v_bricks", "v_fest", "v_churchill", "v_blueprint", "v_tempera", "v_samples", "v_man", 
"v_bees", "v_pomegranate", "v_bip", "v_feast", "v_azzola", "v_woman", "v_yuri", "v_posters",
"v_bird", "v_graffiti", "v_bark", "v_wormhole", "v_tabletop", "v_courses", "v_birdwoman",
"v_gardens"

Split scheme: full

"test":
"i_crownnight", "i_table", "i_objects", "i_nescafe", "i_nijmegen", "i_whitebuilding", "i_porta", 
"i_santuario", "i_dc", "i_castle", "i_steps", "i_contruction", "i_melon", "i_yellowtent", 
"i_miniature", "i_troulos", "i_veggies", "i_zion", "i_gonnenberg", "i_autannes", "i_boutique", 
"i_fruits", "i_pool", "i_fog", "i_fenis", "i_village", "i_ajuntament", "i_partyfood", "i_kurhaus",
 "i_school", "i_chestnuts", "i_smurf", "i_indiana", "i_pinard", "i_lionnight", "i_kions", 
"i_ski", "i_greenhouse", "i_ktirio", "i_tools", "i_toy", "i_bridger", "i_lionday", "i_brooklyn",
 "i_crownday", "i_londonbridge", "i_greentea", "i_leuven", "i_nuts", "i_resort", "i_bologna", 
"i_duda", "i_dome", "i_pencils", "i_books", "i_parking", "i_salon", "v_circus", "v_charing", 
"v_colors", "v_astronautis", "v_maskedman", "v_talent", "v_london", "v_underground", "v_coffeehouse",
 "v_calder", "v_grace", "v_yard", "v_dogman", "v_laptop", "v_eastsouth", "v_boat", "v_strand", 
"v_busstop", "v_artisans", "v_machines", "v_soldiers", "v_home", "v_wapping", "v_wounded", 
"v_weapons", "v_adam", "v_there", "v_vitro", "v_cartooncity", "v_abstract", "v_dirtywall", 
"v_beyus", "v_apprentices", "v_sunseason", "v_wall", "v_war", "v_bricks", "v_fest", "v_churchill",
 "v_blueprint", "v_tempera", "v_samples", "v_man", "v_bees", "v_pomegranate", "v_bip", "v_feast", 
"v_azzola", "v_woman", "v_yuri", "v_posters", "v_bird", "v_graffiti", "v_bark", "v_wormhole", 
"v_tabletop", "v_courses", "v_birdwoman", "v_gardens"

Besides splitting the dataset by scene,

the benchmark also pre-generates the patch pairs for each task on top of these splits, so they can be used directly.

1. verification, patch verification (inter: pairs across different sequences; intra: pairs within the same sequence; neg: negative pairs; pos: positive pairs)

Three files per split (a quick way to inspect them is sketched after the file list below):

verif_pos_split: positive pairs; the two patches come from the same image sequence and the same keypoint;

verif_neg_intra_split: negative pairs; the two patches come from the same image sequence but from non-matching keypoint locations;

verif_neg_inter_split: negative pairs; the two patches come from different image sequences.

verif_pos_split-a.csv
verif_neg_intra_split-a.csv
verif_neg_inter_split-a.csv

verif_pos_split-b.csv
verif_neg_intra_split-b.csv
verif_neg_inter_split-b.csv

verif_pos_split-c.csv
verif_neg_intra_split-c.csv
verif_neg_inter_split-c.csv

verif_pos_split-illum.csv
verif_neg_intra_split-illum.csv
verif_neg_inter_split-illum.csv

verif_pos_split-view.csv
verif_neg_intra_split-view.csv
verif_neg_inter_split-view.csv

verif_pos_split-full.csv
verif_neg_intra_split-full.csv
verif_neg_inter_split-full.csv
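A minimal sketch for inspecting one of these pair files (assuming tskdir as above):

import os
import pandas as pd
from utils.tasks import tskdir

pos = pd.read_csv(os.path.join(tskdir, "verif_pos_split-a.csv"))
print(pos.columns.tolist())  # ['s1', 't1', 'idx1', 's2', 't2', 'idx2']
print(pos.head())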

2. matching (image matching)

3. retrieval, patch retrieval (queries: the query set; distractors: the distractor set)

Patches in the query set can be very similar to patches in the distractor set, which tests whether retrieval picks the wrong one. The files are listed below, followed by a small inspection sketch.

retr_queries_split-a.csv
retr_distractors_split-a.csv

retr_queries_split-b.csv
retr_distractors_split-b.csv

retr_queries_split-c.csv
retr_distractors_split-c.csv

retr_queries_split-illum.csv
retr_distractors_split-illum.csv

retr_queries_split-view.csv
retr_distractors_split-view.csv

retr_queries_split-full.csv
retr_distractors_split-full.csv
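Similarly, a minimal sketch to check the query/distractor list sizes (again assuming tskdir):

import os
import pandas as pd
from utils.tasks import tskdir

q = pd.read_csv(os.path.join(tskdir, "retr_queries_split-a.csv"))
d = pd.read_csv(os.path.join(tskdir, "retr_distractors_split-a.csv"))
print(q.shape, d.shape)  # roughly (5000, 2) and (20000, 2); columns are ['s', 'idx']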

Example with split a:

verification

(Each row s1,t1,idx1,s2,t2,idx2 gives, for patch 1 and patch 2 respectively: the scene name, the patch-type index (see id2t below), and the patch index.)

The pairs are generated randomly by this function:

def gen_verif(seqs,split,N_pos=1e6,N_neg=1e6):
    np.random.seed(42)

    # positives: same scene, same patch index, two different patch types
    s = np.random.choice(split['test'], int(N_pos))
    seq2len = seqs_lengths(seqs)
    s_N = [seq2len[k] for k in s]
    s_idx = np.array([np.random.choice(np.arange(k),2,replace=False) for k in s_N])
    s_type = np.array([np.random.choice(np.arange(5),2,replace=False) for k in s_idx])
    df = pd.DataFrame({'s1': pd.Series(s, dtype=object),\
                       's2': pd.Series(s, dtype=object),\
                       'idx1': pd.Series(s_idx[:,0], dtype=int) ,\
                       'idx2': pd.Series(s_idx[:,0], dtype=int) ,\
                       't1': pd.Series(s_type[:,0], dtype=int) ,\
                       't2': pd.Series(s_type[:,1], dtype=int)})
    df = df[['s1','t1','idx1','s2','t2','idx2']] # updated order for matlab comp.
    df.to_csv(os.path.join(tskdir, 'verif_pos_split-'+split['name']+'.csv'),index=False)

    # intra-sequence negatives: same scene, two different patch indices
    df = pd.DataFrame({'s1': pd.Series(s, dtype=object),\
                       's2': pd.Series(s, dtype=object),\
                       'idx1': pd.Series(s_idx[:,0], dtype=int) ,\
                       'idx2': pd.Series(s_idx[:,1], dtype=int) ,\
                       't1': pd.Series(s_type[:,0], dtype=int) ,\
                       't2': pd.Series(s_type[:,1], dtype=int)})
    df = df[['s1','t1','idx1','s2','t2','idx2']] # updated order for matlab comp.
    df.to_csv(os.path.join(tskdir, 'verif_neg_intra_split-'+split['name']+'.csv'),index=False)

    # inter-sequence negatives: the second patch comes from an independently drawn scene
    s_inter = np.random.choice(split['test'], int(N_neg))
    s_N_inter = [seq2len[k] for k in s_inter]
    s_idx_inter = np.array([np.random.randint(k) for k in s_N_inter])
    df = pd.DataFrame({'s1': pd.Series(s, dtype=object),\
                       's2': pd.Series(s_inter, dtype=object),\
                       'idx1': pd.Series(s_idx[:,0], dtype=int) ,\
                       'idx2': pd.Series(s_idx_inter, dtype=int) ,\
                       't1': pd.Series(s_type[:,0], dtype=int) ,\
                       't2': pd.Series(s_type[:,1], dtype=int)})
    df = df[['s1','t1','idx1','s2','t2','idx2']] # updated order for matlab comp.
    df.to_csv(os.path.join(tskdir, 'verif_neg_inter_split-'+split['name']+'.csv'),index=False)

    

Each patch-type index 0..5 is mapped through a dictionary to the corresponding patch image set (e.g. id2t[3]['e'] == 'e3'):

id2t = {0:{'e':'ref','h':'ref','t':'ref'}, \
        1:{'e':'e1','h':'h1','t':'t1'}, \
        2:{'e':'e2','h':'h2','t':'t2'}, \
        3:{'e':'e3','h':'h3','t':'t3'}, \
        4:{'e':'e4','h':'h4','t':'t4'}, \
        5:{'e':'e5','h':'h5','t':'t5'} }

In practice, though, I never saw a type index of 5, and could not find one being generated in the code either. The reason is visible in gen_verif above: s_type is drawn with np.random.choice(np.arange(5), 2, replace=False), and np.arange(5) only yields 0..4, so the entry for 5 in the mapping is simply never used.
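A standalone check of that claim:

import numpy as np

np.random.seed(42)
draws = np.array([np.random.choice(np.arange(5), 2, replace=False)
                  for _ in range(100000)])
print(draws.min(), draws.max())  # 0 4 -- type index 5 is never drawn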

#####################
# Verification task #
#####################
def get_verif_dists(descr,pairs,op):
    '''
    pairs is one of the loaded pair files:
    verif_pos_split-a.csv
    verif_neg_intra_split-a.csv
    verif_neg_inter_split-a.csv
    :param descr: loaded descriptors, keyed by sequence name (plus a 'distance' entry)
    :param pairs: the pair array loaded from the csv
    :param op: 1/2/3, only used to label the progress bar
    :return: dict of distances for the noise levels 'e', 'h', 't'
    '''
    d = {}
    for t in ['e','h','t']:
        d[t] = np.empty((pairs.shape[0],1))
    idx = 0
    pbar = tqdm(pairs)
    pbar.set_description("Processing verification task %i/3 " % op)
    # iterate over the 1M matching (or non-matching) pairs
    for p in pbar:
        [t1,t2] = [id2t[p[1]],id2t[p[4]]]
        # debug: check whether type index 5 ever appears (it does not, see above)
        if p[1]==5 or p[4]==5:
            print(p)
        # iterate over the noise levels tp = ['e','h','t']
        for t in tp:
            # fetch the descriptors of the two patches
            d1 = getattr(descr[p[0]], t1[t])[p[2]]
            d2 = getattr(descr[p[3]], t2[t])[p[5]]
            # compute the distance between the two descriptors
            distance = descr['distance']
            if distance=='L2':
                dist = spatial.distance.euclidean(d1, d2)
            elif distance=='L1':
                dist = spatial.distance.cityblock(d1, d2)
            elif distance=='masked_L1':
                [d1,m1] = np.array_split(d1, 2)
                [d2,m2] = np.array_split(d2, 2)
                dist = spatial.distance.cityblock(m1*d1, m2*d2)
            d[t][idx] = dist
        idx+=1
    # d holds the descriptor distances: three groups (EASY/HARD/TOUGH), 1M distances each
    return d
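For reference, the two distance options used above behave as follows (a standalone example):

import numpy as np
from scipy import spatial

d1 = np.array([0., 1., 2.])
d2 = np.array([1., 1., 0.])
print(spatial.distance.euclidean(d1, d2))   # L2: sqrt(1 + 0 + 4) = 2.236...
print(spatial.distance.cityblock(d1, d2))   # L1: 1 + 0 + 2 = 3.0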

Then, for each noise level, the balanced protocol computes fpr, tpr and auc, while the imbalanced protocol computes pr, rc and ap:

def eval_verification(descr,split):
    print('>> Evaluating %s task' % green('verification'))

    start = time.time()
    # .as_matrix() is deprecated in newer pandas (hence the FutureWarnings above); use .values instead
    pos = pd.read_csv(os.path.join(tskdir, 'verif_pos_split-'+split['name']+'.csv')).as_matrix()
    neg_intra = pd.read_csv(os.path.join(tskdir, 'verif_neg_intra_split-'+split['name']+'.csv')).as_matrix()
    neg_inter = pd.read_csv(os.path.join(tskdir, 'verif_neg_inter_split-'+split['name']+'.csv')).as_matrix()

    d_pos = get_verif_dists(descr,pos,1)
    d_neg_intra = get_verif_dists(descr,neg_intra,2)
    d_neg_inter = get_verif_dists(descr,neg_inter,3)

    results = defaultdict(lambda: defaultdict(lambda:defaultdict(dict)))

    for t in tp:
        l = np.vstack((np.zeros_like(d_pos[t]),np.ones_like(d_pos[t])))
        d_intra = np.vstack((d_neg_intra[t],d_pos[t]))
        d_inter = np.vstack((d_neg_inter[t],d_pos[t]))

        # get results for the balanced protocol: 1M Positives - 1M Negatives
        fpr,tpr,auc = metrics.roc(-d_intra,l)
        results[t]['intra']['balanced']['fpr'] = fpr
        results[t]['intra']['balanced']['tpr'] = tpr
        results[t]['intra']['balanced']['auc'] = auc

        fpr,tpr,auc = metrics.roc(-d_inter,l)
        results[t]['inter']['balanced']['fpr'] = fpr
        results[t]['inter']['balanced']['tpr'] = tpr
        results[t]['inter']['balanced']['auc'] = auc

        # get results for the imbalanced protocol: 0.2M Positives - 1M Negatives
        N_imb = d_pos[t].shape[0] + int(d_pos[t].shape[0]*0.2) # 1M + 0.2*1M
        pr,rc,ap = metrics.pr(-d_intra[0:N_imb],l[0:N_imb])
        results[t]['intra']['imbalanced']['pr'] = pr
        results[t]['intra']['imbalanced']['rc'] = rc
        results[t]['intra']['imbalanced']['ap'] = ap

        pr,rc,ap = metrics.pr(-d_inter[0:N_imb],l[0:N_imb])
        results[t]['inter']['imbalanced']['pr'] = pr
        results[t]['inter']['imbalanced']['rc'] = rc
        results[t]['inter']['imbalanced']['ap'] = ap
    end = time.time()
    print(">> %s task finished in %.0f secs  " % (green('Verification'),end-start))
    return results
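Once cached by hpatches_eval.py, a results pickle can be inspected directly (a minimal sketch; the file name follows the descr_name+task+split pattern used above):

import dill

res = dill.load(open("results/sift_verification_a.p", "rb"))
print(res['e']['inter']['balanced']['auc'])  # e.g. 0.930915 for split a, cf. the table above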

The metric computation code (utils/metrics.py):

import numpy as np

# TODO: add documentation

def tpfp(scores,labels,numpos=None):
    # count labels
    p = int(np.sum(labels))
    n = len(labels)-p

    if numpos is not None:
        assert(numpos>=p), \
            'numpos smaller than number of positives in labels'
        extra_pos = numpos-p
        p = numpos
        scores = np.hstack((scores,np.repeat(-np.inf, extra_pos)))
        labels = np.hstack((labels,np.repeat(1, extra_pos)))
    
    perm = np.argsort(-scores, kind='mergesort',axis=0)
    
    scores = scores[perm]
    # assume that data with -INF score is never retrieved
    stop = np.max(np.where(scores > -np.inf))

    perm = perm[0:stop+1]

    labels = labels[perm]
    # accumulate true positives and false positives by scores    
    tp = np.hstack((0, np.cumsum(labels == 1)))
    fp = np.hstack((0, np.cumsum(labels == 0)))

    return tp,fp,p,n,perm

def pr(scores,labels,numpos=None):
    [tp,fp,p,n,perm] = tpfp(scores,labels,numpos)
    
    # compute precision and recall
    small = 1e-10
    recall = tp / float(np.maximum(p, small))
    precision = np.maximum(tp, small) / np.maximum(tp+fp, small)

    return precision,recall,np.trapz(precision,recall)


def roc(scores,labels,numpos=None):
    [tp,fp,p,n,perm] = tpfp(scores,labels,numpos)
    
    # compute tpr and fpr
    small = 1e-10
    tpr = tp / float(np.maximum(p, small))
    fpr = fp / float(np.maximum(n, small))

    return fpr,tpr,np.trapz(tpr,fpr)
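A toy usage example (assuming the functions above are saved as utils/metrics.py; note that eval_verification passes negated distances, so that a higher score means a more confident match):

import numpy as np
from utils import metrics

scores = np.array([0.9, 0.8, 0.7, 0.6])  # higher score = more confident match
labels = np.array([1, 0, 1, 0])          # 1 = positive pair, 0 = negative pair

fpr, tpr, auc = metrics.roc(scores, labels)
precision, recall, ap = metrics.pr(scores, labels)
print("auc = %.3f, ap = %.3f" % (auc, ap))  # auc = 0.750, ap = 0.792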

 

retrieval

(Each row s, idx gives the scene name and the patch index of a patch to retrieve.)

The query and distractor lists are generated randomly by this function:

def gen_retrieval(seqs,split,N_queries=0.5*1e4,N_distractors=2*1e4):
    np.random.seed(42)
    seq2len = seqs_lengths(seqs)

    s_q = np.random.choice(split['test'], int(N_queries*4))
    s_q_N = [seq2len[k] for k in s_q]
    s_q_idx = [np.random.randint(k) for k in s_q_N]
    s_q_idx = np.array(s_q_idx)

    s_d = np.random.choice(split['test'], int(N_distractors*10))
    s_d_N = [seq2len[k] for k in s_d]
    s_d_idx = [np.random.randint(k) for k in s_d_N]
    s_d_idx = np.array(s_d_idx)

    # keep only query patches with some texture (std of pixel intensities > 10)
    msk = np.zeros((s_q.shape[0],))
    for i in range(s_q.shape[0]):
        p = get_patch(seqs[s_q[i]],'ref',s_q_idx[i])
        if np.std(p) > 10:
            msk[i] = 1

    msk = np.where(msk==1)
    s_q = s_q[msk]
    s_q_idx = s_q_idx[msk]

    # apply the same texture filter to the distractors
    msk = np.zeros((s_d.shape[0],))
    for i in range(s_d.shape[0]):
        p = get_patch(seqs[s_d[i]],'ref',s_d_idx[i])
        if np.std(p) > 10:
            msk[i] = 1

    msk = np.where(msk==1)
    s_d = s_d[msk]
    s_d_idx = s_d_idx[msk]
    q_  = np.stack((s_q, s_q_idx), axis=-1)
    d_  = np.stack((s_d, s_d_idx), axis=-1)

    df_q = pd.DataFrame({'s': pd.Series(q_[:,0], dtype=object),\
                       'idx': pd.Series(q_[:,1], dtype=int)})
    df_q = df_q[['s','idx']] # updated order for matlab comp.

    df_d = pd.DataFrame({'s': pd.Series(d_[:,0], dtype=object),\
                       'idx': pd.Series(d_[:,1], dtype=int)})
    df_d = df_d[['s','idx']] # updated order for matlab comp.


    print(df_q.shape,df_d.shape)
    df_q = df_q.drop_duplicates()
    df_d = df_d.drop_duplicates()
    df_q = df_q.head(int(N_queries))  # head() needs an int; N_queries is a float
    df_q_ = df_q.copy()

    # drop any patch that appears in both the query and the distractor lists
    common = df_q_.merge(df_d,on=['s','idx'])
    # print(common.shape)
    df_q_.set_index(['s', 'idx'], inplace=True)
    df_d.set_index(['s', 'idx'], inplace=True)
    df_d = df_d[~df_d.index.isin(df_q_.index)].reset_index()
    # print(df_q.shape,df_d.shape)
    df_d = df_d.head(int(N_distractors))

    df_q.to_csv(os.path.join(tskdir, 'retr_queries_split-'+split['name']+'.csv'),index=False)
    df_d.to_csv(os.path.join(tskdir, 'retr_distractors_split-'+split['name']+'.csv'),index=False)

Plotting with Python

The results figure uses matplotlib's usetex mode, which requires a LaTeX installation:

sudo apt install texlive-latex-base

sudo apt install texlive-latex-extra

Add two new files under hpatches-benchmark-master/python/utils:

config.py

"""Code for configuring the appearance of the HPatches results figure.

For each descriptor, a name and a colour that will be used in the
figure can be configured.

You can add new descriptors as shown example below:
'desc':
Descriptor(name='Desc++', color='darksalmon'),

This will add a new descriptor, using the results from the `desc`
folder, with name appearing in the figure as `Desc++`, and
darksalmon colour.

The colour string, has to be a valid names colour, from the following
list:
https://matplotlib.org/examples/color/named_colors.html

Note that `new_descriptor` should match the name of the folder
containing the HPatches results.
"""
import collections

Descriptor = collections.namedtuple('Descriptor', 'name color')
desc_info = {
    'sift': Descriptor(name='SIFT', color='seagreen'),
    'rootsift': Descriptor(name='RSIFT', color='olive'),
    'orb': Descriptor(name='ORB', color='skyblue'),
    'brief': Descriptor(name='BRIEF', color='darkcyan'),
    'binboost': Descriptor(name='BBoost', color='steelblue'),
    'tfeat-liberty': Descriptor(name='TFeat-LIB', color='teal'),
    'geodesc': Descriptor(name='GeoDesc', color='tomato'),
    'hardnet-liberty': Descriptor(name='HNet-LIB', color='chocolate'),
    'deepdesc-ubc': Descriptor(name='DDesc-LIB', color='black'),
    'NCC': Descriptor(name='LearnedSIFT', color='blue'),
    # Add you own descriptors as below:
    # 'desc':
    # Descriptor(name='Desc++', color='darksalmon'),
}

# Symbols for the figure
figure_attributes = {
    'intra_marker': ".",
    'inter_marker': "d",
    'viewp_marker': "*",
    'illum_marker': r"$\diamond$",
    'easy_colour': 'green',
    'hard_colour': "purple",
    'tough_colour': "red",
}

vis_results.py

import dill
import os.path
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from utils.config import desc_info, figure_attributes
import itertools
import operator
import matplotlib.patches as patches

ft = {'e': 'Easy', 'h': 'Hard', 't': 'Tough'}


class DescriptorMatchingResult:
    def __init__(self, desc, splt):
        matching_results = defaultdict(
            lambda: defaultdict(lambda: defaultdict(dict)))
        res = dill.load(
            open(
                os.path.join("results",
                             desc + "_matching_" + splt['name'] + ".p"), "rb"))

        for seq in res:
            seq_type = seq.split("_")[0]
            for t in ft.keys():
                APs = [res[seq][t][idx]['ap'] for idx in range(1, 6)]
                mAP = np.mean(APs)
                matching_results[t][seq_type][seq] = mAP

        cases = list(itertools.product(ft.keys(), ['v', 'i']))
        setattr(self, "avg_v", 0)
        setattr(self, "avg_i", 0)

        for itm in cases:
            noise_type, seq_type = itm[0], itm[1]
            val = 100 * np.mean(
                list(matching_results[noise_type][seq_type].values()))
            setattr(self, str(itm), val)
            setattr(self, "avg_" + seq_type,
                    getattr(self, "avg_" + seq_type) + val)

        n_samples = float(len(cases)) / 2.0
        setattr(self, "avg_i", getattr(self, "avg_i") / n_samples)
        setattr(self, "avg_v", getattr(self, "avg_v") / n_samples)
        setattr(self, "avg",
                (getattr(self, "avg_i") + getattr(self, "avg_v")) / 2.0)


class DescriptorRetrievalResult:
    def __init__(self, desc, splt):
        res = dill.load(
            open(
                os.path.join("results",
                             desc + "_retrieval_" + splt['name'] + ".p"),
                "rb"))

        retrieval_results = defaultdict(lambda: defaultdict(dict))
        pool_sizes = [100, 500, 1000, 5000, 10000, 15000, 20000]
        for psize in pool_sizes:
            for t in ft.keys():
                retrieval_results[t][psize] = []

        for q_idx in res.keys():
            for t in ft.keys():
                for psize in pool_sizes:
                    retrieval_results[t][psize].append(
                        100 * res[q_idx][t][psize]['ap'])

        for t in ft.keys():
            avg_t = 0
            for psize in pool_sizes:
                avg_t += np.mean(retrieval_results[t][psize])
            setattr(self, t, avg_t / float(len(pool_sizes)))
        setattr(self, "avg",
                (getattr(self, "e") + getattr(self, "h") + getattr(self, "t"))
                / 3.0)


class DescriptorVerificationResult:
    def __init__(self, desc, splt):
        metric = {'balanced': 'auc', 'imbalanced': 'ap'}
        self.desc = desc
        self.splt = splt

        res = dill.load(
            open(
                os.path.join(
                    "results",
                    self.desc + "_verification_" + self.splt['name'] + ".p"),
                "rb"))

        cases = list(
            itertools.product(ft.keys(), ['intra', 'inter'],
                              ['balanced', 'imbalanced']))

        setattr(self, "avg_balanced", 0)
        setattr(self, "avg_imbalanced", 0)

        for itm in cases:
            noise_type, negs_type, balance_type = itm[0], itm[1], itm[2]
            val = 100 * res[noise_type][negs_type][balance_type][
                metric[balance_type]]
            setattr(self, str(itm), val)
            setattr(self, "avg_" + balance_type,
                    getattr(self, "avg_" + balance_type) + val)

        n_samples = float(len(cases)) / 2.0
        setattr(self, "avg_balanced",
                getattr(self, "avg_balanced") / n_samples)
        setattr(self, "avg_imbalanced",
                getattr(self, "avg_imbalanced") / n_samples)


class DescriptorHPatchesResult:
    def __init__(self, desc, splt):
        self.desc = desc
        self.splt = splt
        self.verification = DescriptorVerificationResult(desc, splt)
        self.matching = DescriptorMatchingResult(desc, splt)
        self.retrieval = DescriptorRetrievalResult(desc, splt)


def plot_verification(hpatches_results, ax, **kwargs):
    hpatches_results.sort(
        key=operator.attrgetter('verification.avg_imbalanced'), reverse=True)
    descrs = [x.desc for x in hpatches_results]
    y_pos = np.arange(len(descrs))

    avg_verifs = [
        getattr(x.verification, "avg_imbalanced") for x in hpatches_results
    ]

    cases = list(
        itertools.product(ft.keys(), ['intra', 'inter'], ['imbalanced']))
    verification_results = {}
    for case in cases:
        case_results = []
        for descr_result in hpatches_results:
            case_results.append(getattr(descr_result.verification, str(case)))
        verification_results[str(case)] = case_results

    ax.set_axisbelow(True)
    ax.xaxis.grid(color='gray', linestyle='dashed')
    ax.yaxis.grid(color='gray', linestyle='dashed')

    ax.set_yticks(y_pos)
    ax.set_yticklabels([desc_info[x].name for x in descrs], fontsize=14)
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels(np.arange(0, 101, 20), fontsize=12)

    ax.barh(
        y_pos,
        avg_verifs,
        color=[desc_info[x].color for x in descrs],
        edgecolor='k',
        linewidth=1.5,
        alpha=0.8)

    ax.plot(
        verification_results[str(('e', 'inter', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['inter_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['easy_colour'])
    ax.plot(
        verification_results[str(('e', 'intra', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['intra_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['easy_colour'])
    ax.plot(
        verification_results[str(('h', 'inter', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['inter_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['hard_colour'])
    ax.plot(
        verification_results[str(('h', 'intra', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['intra_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['hard_colour'])
    ax.plot(
        verification_results[str(('t', 'inter', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['inter_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['tough_colour'])
    ax.plot(
        verification_results[str(('t', 'intra', 'imbalanced'))],
        y_pos,
        marker=figure_attributes['intra_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['tough_colour'])
    ax.set_xlim([0, 100])

    for i, v in enumerate(avg_verifs):
        ax.text(
            101,
            y_pos[i] - 0.13,
            r"{:.2f}".format(v),
            color='black',
            fontsize=14)

    inter_symbol = mlines.Line2D([], [],
                                 color='black',
                                 marker=figure_attributes['inter_marker'],
                                 linestyle='None',
                                 markersize=5,
                                 label=r'\textsc{Inter}')
    intra_symbol = mlines.Line2D([], [],
                                 color='black',
                                 marker=figure_attributes['intra_marker'],
                                 linestyle='None',
                                 markersize=5,
                                 label=r'\textsc{Intra}')
    ax.legend(
        handles=[inter_symbol, intra_symbol],
        loc='lower center',
        ncol=2,
        bbox_to_anchor=(0.34, 1, .3, .0),
        handletextpad=-0.5,
        columnspacing=0,
        fontsize=12,
        frameon=False)

    return ax


def plot_matching(hpatches_results, ax, **kwargs):
    hpatches_results.sort(
        key=operator.attrgetter('matching.avg'), reverse=True)
    descrs = [x.desc for x in hpatches_results]
    y_pos = np.arange(len(descrs))

    avg_verifs = [getattr(x.matching, "avg") for x in hpatches_results]

    cases = list(itertools.product(ft.keys(), ['v', 'i']))
    matching_results = {}
    for case in cases:
        case_results = []
        for descr_result in hpatches_results:
            case_results.append(getattr(descr_result.matching, str(case)))
        matching_results[str(case)] = case_results

    ax.set_axisbelow(True)
    ax.xaxis.grid(color='gray', linestyle='dashed')
    ax.yaxis.grid(color='gray', linestyle='dashed')

    ax.set_yticks(y_pos)
    ax.set_yticklabels([desc_info[x].name for x in descrs], fontsize=14)
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels(np.arange(0, 101, 20), fontsize=12)

    ax.barh(
        y_pos,
        avg_verifs,
        color=[desc_info[x].color for x in descrs],
        edgecolor='k',
        linewidth=1.5,
        alpha=0.8)

    ax.plot(
        matching_results[str(('e', 'v'))],
        y_pos,
        marker=figure_attributes['viewp_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['easy_colour'])
    ax.plot(
        matching_results[str(('e', 'i'))],
        y_pos,
        marker=figure_attributes['illum_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['easy_colour'])
    ax.plot(
        matching_results[str(('h', 'v'))],
        y_pos,
        marker=figure_attributes['viewp_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['hard_colour'])
    ax.plot(
        matching_results[str(('h', 'i'))],
        y_pos,
        marker=figure_attributes['illum_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['hard_colour'])
    ax.plot(
        matching_results[str(('t', 'v'))],
        y_pos,
        marker=figure_attributes['viewp_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['tough_colour'])
    ax.plot(
        matching_results[str(('t', 'i'))],
        y_pos,
        marker=figure_attributes['illum_marker'],
        linestyle="",
        alpha=0.8,
        color=figure_attributes['tough_colour'])
    ax.set_xlim([0, 100])

    for i, v in enumerate(avg_verifs):
        ax.text(
            101,
            y_pos[i] - 0.13,
            r"{:.2f}".format(v),
            color='black',
            fontsize=14)

    view_symbol = mlines.Line2D([], [],
                                color='black',
                                marker=figure_attributes['viewp_marker'],
                                linestyle='None',
                                markersize=5,
                                label=r'\textsc{Viewp}')
    illum_symbol = mlines.Line2D([], [],
                                 color='black',
                                 marker=figure_attributes['illum_marker'],
                                 linestyle='None',
                                 markersize=5,
                                 label=r'\textsc{Illum}')
    ax.legend(
        handles=[view_symbol, illum_symbol],
        loc='lower center',
        ncol=2,
        bbox_to_anchor=(0.34, 1, .3, .0),
        handletextpad=-0.5,
        columnspacing=0,
        fontsize=12,
        frameon=False)
    return ax


def plot_retrieval(hpatches_results, ax, **kwargs):
    hpatches_results.sort(
        key=operator.attrgetter('retrieval.avg'), reverse=True)
    descrs = [x.desc for x in hpatches_results]
    y_pos = np.arange(len(descrs))

    avg_verifs = [getattr(x.retrieval, "avg") for x in hpatches_results]

    retrieval_results = {}
    for case in ft.keys():
        case_results = []
        for descr_result in hpatches_results:
            case_results.append(getattr(descr_result.retrieval, str(case)))
        retrieval_results[str(case)] = case_results

    ax.set_axisbelow(True)
    ax.xaxis.grid(color='gray', linestyle='dashed')
    ax.yaxis.grid(color='gray', linestyle='dashed')

    ax.set_yticks(y_pos)
    ax.set_yticklabels([desc_info[x].name for x in descrs], fontsize=14)
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels(np.arange(0, 101, 20), fontsize=12)

    ax.barh(
        y_pos,
        avg_verifs,
        color=[desc_info[x].color for x in descrs],
        edgecolor='k',
        linewidth=1.5,
        alpha=0.8)

    ax.plot(
        retrieval_results[str(('e'))],
        y_pos,
        marker="o",
        linestyle="",
        alpha=0.8,
        markersize=4,
        color=figure_attributes['easy_colour'])
    ax.plot(
        retrieval_results[str(('h'))],
        y_pos,
        marker="o",
        linestyle="",
        alpha=0.8,
        markersize=4,
        color=figure_attributes['hard_colour'])
    ax.plot(
        retrieval_results[str(('t'))],
        y_pos,
        marker="o",
        linestyle="",
        alpha=0.8,
        markersize=4,
        color=figure_attributes['tough_colour'])
    ax.set_xlim([0, 100])

    for i, v in enumerate(avg_verifs):
        ax.text(
            101,
            y_pos[i] - 0.13,
            r"{:.2f}".format(v),
            color='black',
            fontsize=14)

    return ax


def plot_hpatches_results(hpatches_results):
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    plt.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{color}')
    n_descrs = len(hpatches_results)
    # The height of the plot for descriptors depend on number of descriptors
    # The 0.8 is absolute value for blank space at the bottom
    bar_width = 0.3
    descr_height = n_descrs * bar_width + 0.8
    
    # Figure height is descriptor plot height plus fixed 1.2 for header
    figh = 1.2 + descr_height
    f, (ax_verification, ax_matching, ax_retrieval) = plt.subplots(1, 3)
    f.set_size_inches(15, figh)
    f.suptitle(
        r'{\bf $\mathbb{H}$Patches Results}', fontsize=22, x=0.5, y=0.98)

    easy_marker = mlines.Line2D([], [],
                                color=figure_attributes['easy_colour'],
                                marker='s',
                                linestyle='None',
                                markersize=4,
                                label=r'\textsc{Easy}')
    hard_marker = mlines.Line2D([], [],
                                color=figure_attributes['hard_colour'],
                                marker='s',
                                linestyle='None',
                                markersize=4,
                                label=r'\textsc{Hard}')
    tough_marker = mlines.Line2D([], [],
                                 color=figure_attributes['tough_colour'],
                                 marker='s',
                                 linestyle='None',
                                 markersize=4,
                                 label=r'\textsc{Tough}')
    plt.figlegend(
        handles=[easy_marker, hard_marker, tough_marker],
        loc='lower center',
        ncol=3,
        bbox_to_anchor=(0.45, 1-0.8/figh, 0.1, 0.0),
        handletextpad=-0.5,
        fontsize=12,
        columnspacing=0)

    # As this is relative, to have fixed bottom/top space we divide by figh
    # We will have then bottom_margin = bottom * figh = 0.8
    # Same for top, top_margin = figh*(1-top) = figh - descr_size = 1.2
    # And the descriptor plot height will then be 
    # figh - top_margin - bottom_margin = n_descrs * bar_height
    # So descriptor plot height will be directly proportional to n_descr
    plt.subplots_adjust(
        left=0.07, bottom=(0.8 / figh), right=None, top=(descr_height / figh), 
        wspace=0.7, hspace=None)

    plot_verification(hpatches_results, ax_verification)
    plot_matching(hpatches_results, ax_matching)
    plot_retrieval(hpatches_results, ax_retrieval)

    ax_verification.set_xlabel(r'Patch Verification mAP [\%]', fontsize=15)
    ax_matching.set_xlabel(r'Image Matching mAP [\%]', fontsize=15)
    ax_retrieval.set_xlabel(r'Patch Retrieval mAP [\%]', fontsize=15)

    f.savefig('hpatches_results.pdf')

Add one new file under hpatches-benchmark-master/python/:

vis_hpatches_results.py

"""Code for printing/plotting results for the HPatches evaluation protocols.

Usage:
  vis_hpatches_results.py (-h | --help)
  vis_hpatches_results.py --version
  vis_hpatches_results.py --descr-name=<>...
                      [--results-dir=<>] [--split=<>] [--pcapl=<>]

Options:
  -h --help         Show this screen.
  --version         Show version.
  --descr-name=<>   Descriptor name e.g. --descr=sift.
  --results-dir=<>  Results root folder. [default: results]
  --split=<>        Split name. Valid are {a,b,c,full,illum,view}. [default: a]

For more visit: https://github.com/hpatches/
"""
from utils.tasks import tskdir
from utils.vis_results import plot_hpatches_results
from utils.vis_results import DescriptorHPatchesResult
import os.path
import json
from utils.docopt import docopt

if __name__ == '__main__':
    opts = docopt(__doc__, version='HPatches 1.0')
    descrs = opts['--descr-name']

    with open(os.path.join(tskdir, "splits", "splits.json")) as f:
        splits = json.load(f)
    splt = splits[opts['--split']]

    hpatches_results = []
    for desc in descrs:
        hpatches_results.append(DescriptorHPatchesResult(desc, splt))

    plot_hpatches_results(hpatches_results)

Run:

python vis_hpatches_results.py --descr-name="sift"  --descr-name="rootsift"  --descr-name="orb"  --results-dir="./results"  --split="a"  --pcapl="no"

This writes hpatches_results.pdf to the current directory.

 

 

 
