Spatial Transform Network:在网络内利用可学习模块对数据进行空间操作

本文介绍了空间变换网络(Spatial Transformer Networks, STN),这是一种用于卷积神经网络(CNN)的新模块,能够学习并执行空间操作,如缩放、裁剪和平移。STN通过局部网络预测变换参数,然后应用参数化的采样网格进行不同iable图像采样,使得网络能够学习到输入数据的空间不变性。实验表明,STN可以在多种任务中提高模型的表现,例如图像分类和识别。
摘要由CSDN通过智能技术生成

Abstract

  • CNN缺乏对输入数据保持空间不变的能力。
  • 引入了一个新的可学习模块,空间转换器,它允许在网络内对数据进行空间操作。
  • 通过为每个输入样本生成适当的变换来积极地对图像(或特征图)进行空间变换,然后在整个特征图上执行转换(非局部),可以包括缩放、裁剪等。

Spatial Transformers

在这里插入图片描述

  • 图2:空间变压器模块的体系结构。输入特征映射U被传递到一个局部网络,该网络回归变换参数θ。将V上的规则空间网格G转换为采样网格 T θ ( G ) T_θ(G) Tθ(G),该网格应用于U,生成扭曲输出特征图V。

Localisation Network

该部分将Feature转变为 变换矩阵θ,用于下一步Parameterised Sampling Grid。

Parameterised Sampling Grid

在这里插入图片描述

  • 图3:(a)采样网格为规则网格 G = T I ( G ) G=T_I(G) G=TI(G),其中 I I I为恒等变换参数。
    (b)采样网格是用仿射变换 T θ ( G ) T_θ(G) Tθ(G)使规则网格变形的结果。

实现细节:
一般来说,常规的CNN输出网格G = { G i G_i Gi}, G i = ( x i t , y i t ) G_i = (x^t_i , y^t_i) Gi=(xit,yit),形成一个输出特性映射 V ∈ R H ′ × W ′ × C V∈R^{H'×W'×C} VRH×W×C H ′ 和 W ′ H'和W' HW为网格的高度和宽度,C是通道的数量,输入和输出是相同的。
下面描述坐标点的变换:
在这里插入图片描述
其中 ( x i s , y i s ) (x^s_i, y^s_i) (xis,yis)为定义样本点的输入特征映射中的源坐标, A θ A_θ Aθ为仿射变换矩阵;

  • 作者使用高度和宽度标准化坐标,这样,当 − 1 ≤ x i t , y i t ≤ 1 −1≤x^t_i, y^t_i≤1 1xit,yit1时, x i t , y i t x^t_i, y^t_i xit,yit在输出的空间范围内, − 1 ≤ x i s , y i s ≤ 1 −1≤x^s_i, y^s_i≤1 1xis,yis1时, x i s , y i s x^s_i, y^s_i xis,yis在输入的空间范围内(y坐标也类似)。
  • 将裁剪、平移、旋转、缩放和倾斜应用于输入特征图,本地化网络只需要生成6个参数( A θ A_θ Aθ的6个元素)。

Differentiable Image Sampling

经过矩阵运算后,绝大多数坐标点数值为float,但是float无法对应feature map上的int坐标值,因此要通过插值计算出对应的像素值,式子如下:
在这里插入图片描述
其中 Φ x 和 Φ y Φx和Φy ΦxΦy为定义图像插值(如双线性)的通用采样核 k ( ) k() k()的参数

k ( ) k() k()为双线性采样核函数:
在这里插入图片描述
对于双线性采样(5),其偏导数为:
在这里插入图片描述

实验

作者先在MNIST上尝试仿射变换,然后分别在SVHN多数字识别和CUB-200-2011鸟类分类数据集上进行实验。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Try

此处尝试论文中的MNIST试验。

import math
import copy

import torch
import torch.nn.functional as F


def vec_to_perpective_matrix(vec):
    # vec rep of the perspective transform has 8 dof; so add 1 for the bottom right of the perspective matrix;
    # note network is initialized to transformer layer bias = [1, 0, 0, 0, 1, 0] so no need to add an identity matrix here
    out = torch.cat((vec, torch.ones((vec.shape[0], 1), dtype=vec.dtype, device=vec.device)), dim=1).reshape(
        vec.shape[0], -1)
    return out.view(-1, 3, 3)


def gen_random_perspective_transform(params):
    """ generate a batch of 3x3 homography matrices by composing rotation, translation, shear, and projection matrices,
    where each samples components from a uniform(-1,1) * multiplicative_factor
    """
    batch_size = params.batch_size

    # debugging
    if params.dict.get('identity_transform_only'):
        return torch.eye(3).repeat(batch_size, 1, 1).to(params.device)

    I = torch.eye(3).repeat(batch_size, 1, 1)
    uniform = torch.distributions.Uniform(-1, 1)
    factor = 0.25
    c = copy.deepcopy

    # rotation component
    a = math.pi / 6 * uniform.sample((batch_size,))
    R = c(I)
    R[:, 0, 0] = torch.cos(a)
    R[:, 0, 1] = - torch.sin(a)
    R[:, 1, 0] = torch.sin(a)
    R[:, 1, 1] = torch.cos(a)
    R.to(params.device)

    # translation component
    tx = factor * uniform.sample((batch_size,))
    ty = factor * uniform.sample((batch_size,))
    T = c(I)
    T[:, 0, 2] = tx
    T[:, 1, 2] = ty
    T.to(params.device)

    # shear component
    sx = factor * uniform.sample((batch_size,))
    sy = factor * uniform.sample((batch_size,))
    A = c(I)
    A[:, 0, 1] = sx
    A[:, 1, 0] = sy
    A.to(params.device)

    # projective component
    px = uniform.sample((batch_size,))
    py = uniform.sample((batch_size,))
    P = c(I)
    P[:, 2, 0] = px
    P[:, 2, 1] = py
    P.to(params.device)

    # compose the homography
    H = R @ T @ P @ A

    return H


def apply_transform_to_batch(im_batch_tensor, transform_tensor):
    """ apply a geometric transform to a batch of image tensors
    args
        im_batch_tensor -- torch float tensor of shape (N, C, H, W)
        transform_tensor -- torch float tensor of shape (1, 3, 3)
    returns
        transformed_batch_tensor -- torch float tensor of shape (N, C, H, W)
    """
    N, C, H, W = im_batch_tensor.shape
    device = im_batch_tensor.device

    # torch.nn.functional.grid_sample takes a grid in [-1,1] and interpolates;
    # construct grid in homogeneous coordinates
    x, y = torch.meshgrid([torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)])
    x, y = x.flatten(), y.flatten()
    xy_hom = torch.stack([x, y, torch.ones(x.shape[0])], dim=0).unsqueeze(0).to(device)

    # tansform the [-1,1] homogeneous coords
    xy_transformed = transform_tensor.matmul(xy_hom)  # 矩阵相乘:(N, 3, 3) matmul (N, 3, H*W) > (N, 3, H*W)
    # convert to inhomogeneous coords -- cf Szeliski eq. 2.21

    grid = xy_transformed[:, :2, :] / (xy_transformed[:, 2, :].unsqueeze(1) + 1e-9)
    grid = grid.permute(0, 2, 1).reshape(-1, H, W, 2)  # (N, H, W, 2); cf torch.functional.grid_sample
    grid = grid.expand(N, *grid.shape[1:])  # expand to minibatch
    print('H',H,'W',W)
    print('grid', grid)
    transformed_batch = F.grid_sample(im_batch_tensor, grid, mode='bilinear', align_corners=True)
    transformed_batch.transpose_(3, 2)

    return transformed_batch


# --------------------
# Test
# --------------------

def test_get_random_perspective_transform():
    import matplotlib
    matplotlib.use('TkAgg')
    import numpy as np
    import matplotlib.pyplot as plt
    from unittest.mock import Mock

    np.random.seed(6)

    im = np.zeros((30, 30))
    im[10:20, 10:20] = 1
    im[20, 20] = 1

    imt = np.array([
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 4 , 4 , 6 , 2 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 ,219,250,253,196,203,202,199,198, 53, 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 1 , 0 , 1 ,62 ,73 ,68 ,62 ,236,104, 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,24 ,253, 3 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,10 ,247,61 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 8 ,250, 4 , 1 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 ,21 ,106,185, 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 ,219,248, 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 38,254, 75, 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 1 ,252,62 , 1 , 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 ,121,252,30 , 1 , 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 8 , 3 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 ]])

    # get transform
    params = Mock()
    params.batch_size = 1
    params.dict = {'identity_transform_only': False}
    params.device = torch.device('cpu')
    H = gen_random_perspective_transform(params)

    imt = imt[np.newaxis, np.newaxis, ...]
    imt = torch.FloatTensor(imt)
    imt_transformed = apply_transform_to_batch(imt, H)
    fig, axs = plt.subplots(2, 2)

    axs[0, 0].imshow(imt.squeeze().numpy(), cmap='gray')
    axs[0, 1].imshow(imt_transformed.squeeze().numpy(), cmap='gray')

    for ax in plt.gcf().axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    test_get_random_perspective_transform()

在这里插入图片描述

  • 左为原图,右为变换、插值后的图.

小结

  • STN与之前学的注意力机制相似,都是从Features中学习到自适应的矩阵,然后作用到原Features上,只是调整的对象不是像素点的值,而是像素点的坐标值(位置)
  • 对图像(Feature)的裁剪、平移、旋转等都可以是对其坐标值进行矩阵运算,因此论文应用6维矩阵对Feature进行仿射变换。

reference

论文地址
Torch代码
仿射变换 相关矩阵运算 - zhihu

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值