STN是很基础的块,网上大多数文章只是讲了STN是如何进行前向传播的,我相信大多数人好奇的是这样一个无监督的块是如何自我学习的。
一、网络结构简介
1.localisation net
一个多层的CNN网络,输入是图片或者特征图,输出为包含六个参数的 θ \theta θ矩阵。
2.Grid generator
根据 θ \theta θ矩阵和原图U的坐标(系),得到变换后的新图V的坐标(系)
3.Sampler
根据新图V的坐标(系)和原图U的像素值,赋予新图以像素。
二、整体定性
输入一张图片或特征图进入模块中,输出一张仿射变换后的图片或特征图。
三、如何学习权重
这一块属于大佬懒得说,菜鸟搞不懂的一块内容。我属于后者。
首先要知道,这个模块的每一个部分都是可以求梯度的。
假设说这是一个手写数字集,最开始时权重是随机的,也就是说
θ
\theta
θ是随机的。那么这个图像经过这个随机变换后也许是面目全非的,此时原本分类网络的loss就会非常高,第一次学习时分类网络的loss会向梯度的方向进行下降,此时所有的W都会进行改变,包括localisation net中的W,此时就相当于更新了一次
θ
\theta
θ,这样STN就会随着分类网络本身的更新进行更新。