通俗易懂的Spatial Transformer Networks(STN)(一)

导读

pytorch为了方便实现STN,里面封装了affine_gridgrid_sample两个高级API。对STN不太了解的同学可以参考这篇详细解读Spatial Transformer Networks(STN)

其实STN的作用是想让CNN具备平移、旋转、缩放、剪切不变性,虽然说CNN中的Pooling可以让网络具备一点平移不变性,但这毕竟是隐性的,如果能让网络直接具备这样的能力岂不是更好。

如果对图像处理有了解的同学也许听过仿射变换这个名词,我们只需要通过变换矩阵 θ \theta θ(由6个参数组成)就能实现上面的这些功能,如果对仿射变换不了解的同学可以参考我的这篇一文搞懂仿射变换

STN也是因为受到这个启发而诞生的,那么我们如何将这种能力嵌入到CNN中呢?这便是STN需要解决的问题

STN简介

在这里插入图片描述

上面引用的文章中已经详细介绍了STN网络,我这里总结概括一下

  • Localisation net

Localisation net模块通过CNN提取图像的特征来预测变换矩阵 θ \theta θ

  • Grid generator

Grid generator模块就是利用Localisation net模块回归出来的 θ \theta θ参数来对图片中的位置进行变换,输入图片到输出图片之间的变换,需要特别注意的是这里指的是图片像素所对应的位置

例如:如果此时 θ \theta θ参数功能是实现图片的平移变换(向右平移1,),输入图片上的坐标(1,1),那对应输出图片上的坐标的(2,1),也就是说输入图片上(1,1)对应的像素值等于输出图片上(2,1)对应的像素值。在变换的时候必然会遇到当输入图片的位置变换到输出图片上是如果位置出现小数怎么办?

  • Sampler

Sampler就是用来解决Grid generator模块变换出现小数位置的问题的。针对这种情况,STN采用的是双线性插值(Bilinear Interpolation),下面我们来介绍一下这个算法
在这里插入图片描述
上图中 ( x , y ) (x,y) (x,y)是变换后输出图像上的位置,带下标的坐标位置表示的是与 ( x , y ) (x,y) (x,y)在输入图像对应的四个相邻的坐标。上面的坐标满足下面的关系
x 1 − x 0 = 1 y 1 − y 0 = 1 x_1-x_0 = 1\\ y1-y_0 = 1 x1x0=1y1y0=1
根据双线性插值的原则距离相邻点近的坐标占的比重越大,所以 ( x , y ) (x,y) (x,y)对应的像素值为,我们用 f ( x , y ) f(x,y) f(x,y)表示点 ( x , y ) (x,y) (x,y)所对应的像素值
f ( x , y ) = ( x 1 − x ) ( y 1 − y ) f ( x 0 , y 0 ) + ( x − x 0 ) ( y 1 − y ) f ( x 1 , y 0 ) = + ( x − x 0 ) ( y − y 0 ) f ( x 1 , y 1 ) + ( x 1 − x ) ( y − y 0 ) f ( x 0 , y 1 ) \begin{aligned} f(x,y) &= (x_1-x)(y1-y)f(x_0,y_0)+(x-x_0)(y_1-y)f(x_1,y_0)\\ &=+(x-x_0)(y-y_0)f(x_1,y_1)+(x_1-x)(y-y_0)f(x_0,y_1) \end{aligned} f(x,y)=(x1x)(y1y)f(x0,y0)+(xx0)(y1y)f(x1,y0)=+(xx0)(yy0)f(x1,y1)+(x1x)(yy0)f(x0,y1)

STN层的实现

  • pytorch的实现

通过pytorchaffine_gridgrid_sample可以很容易实现STN的后两个模块

from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)

#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],
                     dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),
               img_tensor.unsqueeze(0).size(),align_corners=True)
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),
			   grid,align_corners=True)

plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.title("original image")

plt.subplot(1,2,2)
plt.imshow(output[0].numpy().transpose(1,2,0))
plt.title("stn transform image")

plt.show()

在这里插入图片描述

  • numpy的实现

我们通过numpy来实现STN的后两个模块,来帮助大家更好的理解STN

class Grid_sample(object):
    def affine_grid(self,theta,img_size):
        if len(img_size) != 2:
            assert("img_size size must is 2")
        num_batch = np.shape(theta)[0]
        img_w,img_h = img_size
        #将图片位置归一化到(-1,1)
        x = np.linspace(-1.0,1.0,img_w)
        y = np.linspace(-1.0,1.0,img_h)

        #组合x和y获取到图片的位置坐标
        x_t,y_t = np.meshgrid(x,y)
        x_t_flat = np.reshape(x_t,[-1])
        y_t_flat = np.reshape(y_t,[-1])

        #创建一个图片的位置数组
        ones = np.ones_like(x_t_flat)
        sampling_grid = np.stack([x_t_flat,y_t_flat,ones])
        sampling_grid = np.expand_dims(sampling_grid,axis=0)
        sampling_grid = np.tile(sampling_grid,
                                np.stack([num_batch,1,1]))

        #计算变换后的图片位置
        batch_grids = np.matmul(theta,sampling_grid)
        batch_grids = np.reshape(batch_grids,
                                 [num_batch,2,img_h,img_w])

        return batch_grids


    def bilinear_sampler(self,img,batch_grids):
        if (batch_grids.shape) != 4:
            assert("batch_grids shape is must equal 4")
        #获取变换后图片位置的x和y轴的坐标位置
        x = batch_grids[:, 0, :, :]
        y = batch_grids[:, 1, :, :]

        img_w,img_h = img.shape[:2]
        max_x = img_w - 1
        max_y = img_h - 1

        #将变换后的坐标位置固定到(0,w/h-1)
        x = 0.5 * ((x+1.0)*(max_x-1))
        y = 0.5 * ((y+1.0)*(max_y-1))

        #将坐标位置取整,便于从输入图片中获取位置对应的像素值
        x0 = np.floor(x).astype(np.int)
        x1 = x0 + 1
        y0 = np.floor(y).astype(np.int)
        y1 = y0 + 1

        #防止坐标越界
        x0 = np.clip(x0,0,max_x)
        x1 = np.clip(x1,0,max_x)
        y0 = np.clip(y0,0,max_y)
        y1 = np.clip(y1,0,max_y)

        #根据坐标位置,取像素值
        Ia = img[y0,x0,:]
        Ib = img[y1,x0,:]
        Ic = img[y0,x1,:]
        Id = img[y1,x1,:]

        wa = np.expand_dims((x1-x)*(y1-y),axis=3)
        wb = np.expand_dims((x1-x)*(y-y0),axis=3)
        wc = np.expand_dims((x-x0)*(y1-y),axis=3)
        wd = np.expand_dims((x-x0)*(y-y0),axis=3)

        #利用双线性插值计算变换后的像素值
        out = wa*Ia + wb*Ib + wc*Ic + wd*Id

        return out


grid_sampler = Grid_sample()
img = np.array(Image.open("img/test.jpg"))
img_h,img_w = img.shape[:2]
theta = np.array([[[1, 0, 0.1], [0, 1, 0.2]]],dtype=np.float)
theta = np.expand_dims(theta,axis=0)

batch_grids = grid_sampler.affine_grid(theta,(img_w,img_h))
out = grid_sampler.bilinear_sampler(img,batch_grids)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.array(img))
plt.title("original image")

plt.subplot(1, 2, 2)
plt.imshow(out[0].astype(np.uint8))
plt.title("stn transform image")

plt.show()

在这里插入图片描述
下一篇文章我们介绍如何将STN模块插入到CNN中

  • 11
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
空间变换网络(Spatial Transformer NetworksSTN)是一种神经网络结构,用于改善卷积神经网络(CNN)的空间不变性。STN可以对经过平移、旋转、缩放和裁剪等操作的图像进行变换,使得网络在变换后的图像上得到与原始图像相同的检测结果,从而提高分类的准确性。STN由三个主要部分组成:局部化网络(Localisation Network)、参数化采样网格(Parameterised Sampling Grid)和可微分图像采样(Differentiable Image Sampling)。 局部化网络是STN的关键组件,它负责从输入图像中学习如何进行变换。局部化网络通常由卷积和全连接层组成,用于估计变换参数。参数化采样网格是一个由坐标映射函数生成的二维网格,它用于定义变换后每个像素在原始图像中的位置。可微分图像采样则是通过应用参数化采样网格来执行图像的变换,并在变换后的图像上进行采样。 使用STN的主要优点是它能够在不改变网络结构的情况下增加空间不变性。这使得网络能够处理更广泛的变换,包括平移、旋转、缩放和裁剪等。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。 关于STN的代码实现,您可以在GitHub上找到一个示例实现。这个实现使用TensorFlow框架,提供了STN网络的完整代码和示例。您可以通过查看该代码来了解如何在您的项目中使用STN。 综上所述,spatial transformer networks(空间变换网络)是一种神经网络结构,用于增加CNN的空间不变性。它包括局部化网络、参数化采样网格和可微分图像采样三个部分。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。在GitHub上有一个使用TensorFlow实现的STN示例代码供参考。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

修炼之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值