HF-Net(二)基于HF-Net的全局特征定位及局部特征匹配

参考:HF-Net git地址

0.整体架构

图片来源:https://arxiv.org/abs/1812.03506
1.核心代码

import cv2
import numpy as np
from pathlib import Path

from hfnet.settings import EXPER_PATH
from notebooks.utils import plot_images, plot_matches, add_frame

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
tf.contrib.resampler  

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

class HFNet:
    def __init__(self, model_path, outputs):
        self.session = tf.Session()
        self.image_ph = tf.placeholder(tf.float32, shape=(None, None, 3))

        net_input = tf.image.rgb_to_grayscale(self.image_ph[None])
        tf.saved_model.loader.load(
            self.session, [tag_constants.SERVING], str(model_path),
            clear_devices=True,
            input_map={'image:0': net_input})

        graph = tf.get_default_graph()
        self.outputs = {n: graph.get_tensor_by_name(n+':0')[0] for n in outputs}
        self.nms_radius_op = graph.get_tensor_by_name('pred/simple_nms/radius:0')
        self.num_keypoints_op = graph.get_tensor_by_name('pred/top_k_keypoints/k:0')
        
    def inference(self, image, nms_radius=4, num_keypoints=1000):
        inputs = {
            self.image_ph: image[..., ::-1].astype(np.float),
            self.nms_radius_op: nms_radius,
            self.num_keypoints_op: num_keypoints,
        }
        return self.session.run(self.outputs, feed_dict=inputs)

def compute_distance(desc1, desc2):
    return 2 * (1 - desc1 @ desc2.T)

def match_with_ratio_test(desc1, desc2, thresh):
    dist = compute_distance(desc1, desc2)
    nearest = np.argpartition(dist, 2, axis=-1)[:, :2]
    dist_nearest = np.take_along_axis(dist, nearest, axis=-1)
    valid_mask = dist_nearest[:, 0] <= (thresh**2)*dist_nearest[:, 1]
    matches = np.stack([np.where(valid_mask)[0], nearest[valid_mask][:, 0]], 1)
    return matches 

if __name__ == "__main__":
    query_idx = 1  
    read_image = lambda n: cv2.imread('doc/demo/' + n)[:, :, ::-1]
    image_query = read_image(f'query{query_idx}.jpg')
    images_db = [read_image(f'db{i}.jpg') for i in range(1, 5)]
    plot_images([image_query] + images_db, dpi=50)

    model_path = Path(EXPER_PATH, 'saved_models/hfnet')
    outputs = ['global_descriptor', 'keypoints', 'local_descriptors']
    hfnet = HFNet(model_path, outputs)

    db = [hfnet.inference(i) for i in images_db]
    global_index = np.stack([d['global_descriptor'] for d in db])
    query = hfnet.inference(image_query)

    nearest = np.argmin(compute_distance(query['global_descriptor'], global_index))
    disp_db = [add_frame(im, (0, 255, 0)) if i == nearest else im
           for i, im in enumerate(images_db)]
    #plot_images([image_query] + disp_db, dpi=50)

    matches = match_with_ratio_test(query['local_descriptors'],
                                db[nearest]['local_descriptors'], 0.8)
    print(nearest)
    print(len(matches))

    plot_matches(image_query, query['keypoints'],
             images_db[nearest], db[nearest]['keypoints'],
            matches, color=(0, 1, 0), dpi=50)

1.1参数设置
设置查询底库路径,设置查询影像,在编译hf-net时设置的EXPER_PATH路径下创建saved_models文件夹,其下存放hf-net预训练权重
1.2功能实现
底库及查询影像经过hfnet网络后生成’global_descriptor’, ‘keypoints’, ‘local_descriptors’,基于global_descriptor进行粗定位,基于local_descriptors进行精细匹配

2.演示结果

2.1样例查询
底库图像
在这里插入图片描述
查询图像
在这里插入图片描述
基于global_descriptor进行粗定位
在这里插入图片描述
基于local_descriptors进行精细匹配
在这里插入图片描述

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值