论文解析一: SuperPoint 一种自监督网络框架,能够同时提取特征点的位置以及描述子

SuperPoint:一种自监督网络框架,能够同时提取特征点的位置以及描述子

**传统方法的问题:**基于图像块的算法导致特征点位置精度不够准确;特征点与描述子分开进行训练导致运算资源的浪费,网络不够精简,实时性不足;仅仅训练特征点或者描述子的一种,不能用同一个网络进行联合训练。

如下图整个训练流程分为三部分:

(a)Inter Point Pre-Training(特征点预训练)

(b)Interest Point Self-Labeling(自监督标签)

(c)Joint Training (联合训练)

在这里插入图片描述

1.特征点预训练

​ 创建合成数据集 Synthetic Shapes(Labeled Interest Point Images ) 利用合成数据集训练的检测器成MagicPoint(Base Detector)虽然是在合成的数据集上进行训练的,但是论文中提到MagicPoint在Corner-Like Structure(角点类似结构)的现实场景也具备一定的泛化能力,而面对更加普遍的场景,MagicPoint的效果就会下降,为此作者增加了第二步,即Homographic Adaption,第二步 自监督标签

2.自监督标签

在这里插入图片描述
在这里插入图片描述

3.整体网络结构

SuperPoint的网络结构如下图所示:
在这里插入图片描述

3.1 先对图像进行卷积

图像特征通常包括两部分,特征点和特征描述子,上图中两个分支即分别同时提取特征点和特征描述子。网络首先使用了VGG-Style的Encoder用于降低图像尺寸提取特征,Encoder部分由卷积层、Max-Pooling层和非线性激活层组成,通过三个Max-Pooling层将图像尺寸变为输出的1/8,代码如下:

# Shared Encoder
x = self.relu(self.conv1a(data['image']))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
x = self.relu(self.conv2b(x))
x = self.pool(x)
x = self.relu(self.conv3a(x))
x = self.relu(self.conv3b(x))
x = self.pool(x)
x = self.relu(self.conv4a(x))
x = self.relu(self.conv4b(x)) # x的输出维度是(N,128,W/8, H/8)
# N代表样本数量
# 128 表示特征图的通道数  代表输出 x 中包含了 128 个不同的特征图,每个特征图都捕获了输入数据中的不同特征信息
# W/8, H/8 特征图的宽度和高度相对于初始输入图片的尺寸已经缩小了 8 倍。

3.2 特征点提取部分(Interest Point Decoder)

​ 对于特征点提取部分,网络先将维度( W / 8 , H / 8 , 128 )的特征处理为( W / 8 , H / 8 , 65 )大小,这里的65的含义是特征图的每一个像素表示原图8 × 8 的局部区域加上一个当局部区域不存在特征点时用于输出的Dustbin通道,通过Softmax以及Reshape的操作,最终特征会恢复为原图大小。

​ 其中65 是 = 64 +1 先把128通道数 变为64通道数 ,再加一个参数层(Dustbin通道),这一层是为了8×8的局部区域内没有特征点时,经过Softmax后64维的特征势必还是会有一个相对较大的值输出,但加入Dustbin通道后就可以避免这个问题。

​ 然后对得到的特征图,进行**归一化(Softmax)和纬度变化(Reshape)**操作。

具体代码如下:

 # 计算密集关键点分数
 cPa = self.relu(self.convPa(x)) # x维度是(N,128,W/8, H/8)
 scores = self.convPb(cPa) # scores维度是(N,65,W/8, H/8)
 #下面这为什么由65 变为 64 了,是因为对 scores 进行 softmax(归一化操作)操作,得到概率分布,并且去掉最后一个通道
 #这样最后一个参数层便可以去掉了
 scores = torch.nn.functional.softmax(scores, 1)[:, :-1] # scores维度是(N,64,W/8, H/8)
 #获取 scores 的形状信息,在这里,scores 是一个四维张量,表示为(batch_size(数据的批量大小), channels, height, width)
 b, _, h, w = scores.shape
 scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) # 进行纬度变化 scores维度是(N,W/8, H/8, 8, 8)
 scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) # 再次进行纬度变化 scores维度是(N,W/8, H/8)
 #对 scores 进行非极大值抑制(NMS)处理,为了进一步提取关键点或者去除冗余信息,并得到关键点分数
 scores = simple_nms(scores, self.config['nms_radius'])

3.3 特征描述子提取部分(Descriptor Decoder)

​ 对于特征描述子提取部分,同理,我们还是使用encoder层的输出(H,W,128)。经过卷积解码器得到(H,W,256),双线性插值扩大尺寸(H,W,256),最后对每一个像素的描述子(256维)进行L2归一化。

在这里插入图片描述

双三次插值(Bicubic Interpolation)是一种常用的图像处理和计算机图形学中的插值方法,用于在离散网格上对图像进行平滑的插值,它可以用于放大或缩小图像,并且相比线性插值更能保持图像细节和平滑度。

3.4 损失函数

首先我们看下基于特征点和特征向量是如何建立损失函数的,损失函数公式如下:

在这里插入图片描述

4.实验结果对比 和 不足

在这里插入图片描述

​ 可以看到第四行SuperPoint的表现其实很差,我们知道单应变化成立的前提条件是场景中存在平面,而第四行的场景很显然不太满足这种条件,因此效果也会比较差,这也体现了SuperPoint在自监督训练下的一些不足之处。

5.整体代码详解

from pathlib import Path
import torch
from torch import nn

#非极大值抑制(Non-maximum Suppression,NMS)算法,用于移除邻近的点以实现特征点的稀疏化
def simple_nms(scores, nms_radius: int):
    """ Fast Non-maximum suppression to remove nearby points """
    assert(nms_radius >= 0)

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)

#用于移除位于图像边界附近的关键点
def remove_borders(keypoints, scores, border: int, height: int, width: int):
    """ Removes keypoints too close to the border """
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
    mask = mask_h & mask_w
    return keypoints[mask], scores[mask]

#从关键点集合中选择具有最高分数的前 k 个关键点
def top_k_keypoints(keypoints, scores, k: int):
    if k >= len(keypoints):
        return keypoints, scores
    scores, indices = torch.topk(scores, k, dim=0)
    return keypoints[indices], scores

#在关键点位置对描述子进行插值采样  双线性差值
def sample_descriptors(keypoints, descriptors, s: int = 8):
    """ Interpolate descriptors at keypoint locations """
    b, c, h, w = descriptors.shape
    keypoints = keypoints - s / 2 + 0.5
    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
                              ).to(keypoints)[None]
    keypoints = keypoints*2 - 1  # normalize to (-1, 1)
    args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)
    return descriptors


class DetectNet(nn.Module):
    default_config = {
        'descriptor_dim': 256,
        'nms_radius': 4,
        'keypoint_threshold': 0.005,
        'max_keypoints': -1,
        'remove_borders': 4,
    }

    def __init__(self, config):
        super().__init__()
        #根据默认配置和传入的配置参数合并生成最终的配置
        self.config = {**self.default_config, **config}

        self.relu = nn.ReLU(inplace=True)# 定义一个使用inplace操作的ReLU激活函数
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 定义一个最大池化层,kernel_size为卷积核大小,stride为步长
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 # 定义了5个通道数变量
        
        # 定义一系列的卷积层
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)#输入数据的通道数,输出数据的通道数,即卷积核的数量
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        # 定义两个卷积层,用于处理像素位置(P)和描述子(D)
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)

        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(
            c5, self.config['descriptor_dim'],
            kernel_size=1, stride=1, padding=0)

        #声明全连接层
        # self.fc = nn.Linear(128 * 64 * 64, 128 * 64 * 64)  # 输入特征维度为 128*64*64,输出维度为 128 * 64 * 64

        #声明卷积层
        self.conv5a = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)

    
        path = self.config['weights'] #从模型配置中获取预训练权重文件的路径
        pretrained_dict = torch.load(str(path))#使用torch.load 函数加载预训练模型的权重参数,将其存储在 pretrained_dict 中。
        model_dict = self.state_dict()#获取当前模型的状态字典,即当前模型的权重参数。

        #遍历预训练模型的权重参数,只保留那些在当前模型状态字典中存在的键值对,生成一个新的字典 pretrained_dict,用于更新当前模型的参数。
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

        # 更新字典
        model_dict.update(pretrained_dict)#pretrained_dict 中的权重参数更新到当前模型的状态字典 model_dict 中。
        self.load_state_dict(model_dict, strict=False)

        mk = self.config['max_keypoints']#获取配置中的最大关键点数
        if mk == 0 or mk < -1:#如果最大关键点数为 0 或者小于 -1,则抛出值错误。
            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')

        print('Loaded DetectNet model')

    def forward(self, data):
        """ Compute keypoints, scores, descriptors for image """
        # Shared Encoder  共享编码器
        # x = self.relu(self.conv1a(data['image']))
        x = self.relu(self.conv1a(data))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))#(1,128,64,64)

        x = self.relu(self.conv5a(x))  # (1,128,64,64)
        # # 将 x 展平为一维向量,以便作为全连接层的输入
        # x = x.view(x.size(0), -1)
        # #全连接层的输入
        # x = self.fc(x)
        # # 将 x 展开为(1, 128, 64, 64)的维度
        # x = x.view(1, 128, 64, 64)

        # Compute the dense keypoint scores 计算稠密关键点分数  Interest Point Decoder(特征点提取部分)
        cPa = self.relu(self.convPa(x))#(1,256,64,64)
        scores = self.convPb(cPa)#(1,65,64,64)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]#(1,64,64,64)  #Softmax函数将分数转换为概率,然后删除最后一个通道,即无效通道
        b, _, h, w = scores.shape #
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
        scores = simple_nms(scores, self.config['nms_radius'])#对 scores 进行非极大值抑制(NMS)处理,为了进一步提取关键点或者去除冗余信息,并得到关键点分数

        # Extract keypoints   提取关键点
        keypoints = [
            torch.nonzero(s > self.config['keypoint_threshold'])
            for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders  丢弃靠近图像边界的关键点
        keypoints, scores = list(zip(*[
            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
            for k, s in zip(keypoints, scores)]))

        # Keep the k keypoints with highest score  保留得分最高的 k 个关键点
        if self.config['max_keypoints'] >= 0:
            keypoints, scores = list(zip(*[
                top_k_keypoints(k, s, self.config['max_keypoints'])
                for k, s in zip(keypoints, scores)]))

        # Convert (h, w) to (x, y)  将 (h, w) 转换为 (x, y)
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        # Compute the dense descriptors  计算稠密描述符   Descriptor Decoder 特征描述子提取部分
        cDa = self.relu(self.convDa(x))#(1,256,64,64)
        descriptors = self.convDb(cDa)#(1,256,64,64)
        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)#(1,256,64,64)  #对描述子进行 L2 归一化

        # Extract descriptors  提取描述符   从关键点位置对描述子进行插值采样
        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
                       for k, d in zip(keypoints, descriptors)]

        return {
            'keypoints': keypoints,
            'scores': scores,
            'descriptors': descriptors,
        }
  • 11
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Unknown To Known

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值