STN(Spatial Transformer Networks)
STN架构
Localisation net(本地网络,参数预测):通过CNN提取图像的特征来预测变换矩阵 θ θ θ 。
假设 θ = [ θ 11 , θ 21 , θ 12 θ 22 , θ 13 , θ 23 ] θ=\left[\begin{matrix}θ_{11},θ_{21},θ_{12}\\θ_{22},θ_{13},θ_{23}\end{matrix}\right] θ=[θ11,θ21,θ12θ22,θ13,θ23]
Grid generator(网格生成器,坐标映射):已知V坐标,利用Localisation net
获得的
θ
θ
θ对图片中像素对应位置进行变换,获得U对应的坐标。(编程时,要预先定义V的shape,pytoch会根据shape的大小自动初始化一个坐标)
假设 V = [ x t y t 1 ] V = \left[ \begin{matrix}x^t\\y^t\\1\end{matrix}\right] V=⎣⎡xtyt1⎦⎤, U = [ x y ] U = \left[ \begin{matrix}x\\y\end{matrix}\right] U=[xy]
对应关系:
U = [ x y ] = θ V = [ θ 11 , θ 21 , θ 12 θ 22 , θ 13 , θ 23 ] [ x t y t 1 ] U = \left[ \begin{matrix}x\\y\end{matrix}\right]=θV=\left[\begin{matrix}θ_{11},θ_{21},θ_{12}\\θ_{22},θ_{13},θ_{23}\end{matrix}\right] \left[ \begin{matrix}x^t\\y^t\\1\end{matrix}\right] U=[xy]=θV=[θ11,θ21,θ12θ22,θ13,θ23]⎣⎡xtyt1⎦⎤
这里采用向后映射。
(向前映射:已知原图像上的坐标,并且已知原图像坐标到目标图像坐标的映射关系,因此可以求得原图像上一点经过映射后在目标图像上的位置。
向后映射:已知目标图像上的坐标,并且已知目标图像坐标到原图像坐标的映射关系,因此可以求得目标图像上一点经过映射后在原图像上的位置。)
Sampler(采样器,像素采集):解决Grid generator
模块变换出现小数位置的问题,利用采样网络和输入特征图生成变换后的结果。(通常情况下一个整数的坐标( x , y ) 经过映射后往往都位于非整数位置,此时就要采用插值方法进行采样。)(使用双线性插值进行采样)
pytorch封装affine_grid
和grid_sample
两个API用于实现STN
input (Tensor): input of shape
theta (Tensor): input batch of affine matrices
size (torch.Size): the target output image size
affine_grid
:根据变换矩阵来计算变换后图片的对应位置
grid = torch.nn.functional.affine_grid(theta, size)
grid_sample
:#默认使用双向性插值,可以通过mode参数设置
outputs = torch.nn.functional.grid_sample(input, grid, mode='bilinear')
#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)
#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],
dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),
img_tensor.unsqueeze(0).size())
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),
grid)