本文参考:
https://www.jianshu.com/p/e3f386771c51
仿射变换(AffineTransform)与仿射矩阵_TracelessLe的博客-CSDN博客_仿射矩阵
Pytorch中的仿射变换(affine_grid)_张博208的博客-CSDN博客_affine_grid
通俗易懂的Spatial Transformer Networks(STN)(一)_修炼之路的博客-CSDN博客
通俗易懂的Spatial Transformer Networks(STN)(二)_修炼之路的博客-CSDN博客
详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn算法 (最重要)
一、仿射变化
1、实质:
仿射变化 = 线性变化 + 一个平移
2、变换公式:
进一步转化为:
或:
所以6个参数决定了一个仿射变化
3、几种变换矩阵
(1)扩大/缩放
(2)旋转
4、仿射变化代码
Opencv和pytorch都提供了仿射变化的函数。
Pytorch针对仿射变化提供了2个函数:
(1)创建grid,该grid为通过仿射后图片的位置坐标信息
grid = torch.nn.functional.affine_grid(theta, size)
(2)grid_sample重采样,根据输入图片变换后图片位置填充像素值
outputs = torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')
mode=’bilinear’的原因是:
首先仿射变化的本质为采样,指定采样前和采样后的位置映射信息,然后把像素值复制过去,当放大后就出现部分位置空缺的情况,此时可用双线性插值填充空缺位置的像素值。
-》
二、STN网络原理
1、概述
STN:spatial transformer net。当输入图片通过STN模块之后获得变换后的图片,然后再将变换后的图片输入到CNN网络中,通过损失函数计算loss,然后计算梯度更新θ参数,最终STN模块会学习到如何矫正图片。
它分为三个部分。
(1)Localisation Net:
通过CNN提取的图像特征来预测变换矩阵θ。
即:根据图片特征决定变换矩阵,所以变换矩阵每张图片都不一样,也不是根据channel来决定变换矩阵。
(2)Grid generator:
根据θ生成变换前后的位置变换映射关系。
(3)Sampler
根据位置变换关系进行像素值采样,并通过双线性插值(Bilinear Interpolation)解决Grid generator模块出现小数位置的问题。
2、Localisation Net实现参数选取
各种仿射变化,都可以通过仿射矩阵实现,只需要六个参数[2*3]控制就可以了。所以我们可以把feature map作为输入,过连续若干层计算(如卷积、FC等),回归出参数θ,在我们的例子中就是一个[2,3]大小的6维仿射变换参数,用于下一步计算。
3、Grid Generator实现像素点坐标的对应关系
缩放旋转的本质,其实就是在原样本上采样,拿到对应的像素点,通俗点说,就是输出的图片(i, j)的位置上,要对应输入图片的哪个位置。
4、Sampler实现坐标求解的可微性
假如θ中都为整数,则源像素点位置对应目标像素点的位置也是整数。假如θ中有小数,但是没有元素的下标索引是小数。用四舍五入显然不能进行梯度下降来回传梯度的。
因为,梯度下降是一步一步调整的,而且调整的数值都比较小,哪怕权值参数有小范围的变化,最后的输出也会有小范围的变化。此时做如下改动:
上述公式首先根据小数的信息确定了本层网络的索引针对上层网络相关索引的组合信息,然后根据小数的值计算距离确定本层的最后像素点值。这样权值都是与结果对应的距离相关的,如果目标图片发生了小范围的变化,这个式子也是可以捕捉到这样的变化的,这样就能用梯度下降法来优化了。
三、STN代码示例
具体代码参考:https://blog.csdn.net/sinat_29957455/article/details/112756934
θ参数是根据特征进行调整的
# 回归theta参数
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 2 * 3)
)
最后是根据图像生成的32位特征得到矫正参数。