解决pytorch旋转变换STN中出现的图像被切割,旋转两次无法还原

 

项目场景:

最近想要在深度学习网络中使用旋转变换,STN旋转变化的几何原理其他博客已经叙述的较为详细,本文不再赘述,本文将从其代码实现方面进行分析,并解决原始代码进行旋转变换存在的图像被切割信息丢失等问题。


问题描述:

STN模型示意图如下,分为三个部分,参数预测,坐标映射,像素采样,其中参数预测是指通过网络来预测仿射变换中的仿射矩阵参数(目前很多STN网络采用的都是这样通过网络来生成仿射矩阵参数,并且没有监督,因此个人认为可解释性不足),坐标映射是指要根据仿射参数,将目标图像中坐标位置映射到原始坐标的位置,即得到目标图像和源图像之间的一个位置对应关系,像素采样则是指根据目标图像和源图像之间位置的对应关系,对源图像进行一个像素采样,这里面可能还包含插值等算法。上面三个部分分别对应Location net,Grid generator,Sampler三个部分。由于本项目只实现旋转变换,并且角度根据其他信息已知,所以旋转变化的代码实现如下(更详细的可以参照链接:https://www.jianshu.com/p/723af68beb2e):

from torch.nn import functional as F
import math

angle = -30*math.pi/180
theta = torch.tensor([
    [math.cos(angle),math.sin(-angle),0],
    [math.sin(angle),math.cos(angle) ,0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

但是如果有人试了一下这个代码可能就会发现一些问题,本文将图片先逆时针旋转20度,然后再顺时针旋转20度,发现有两个问题:1、图片被切割掉了;2、先逆时针旋转后顺时针旋转,发现图像无法回到原始的水平状态


原因分析:

关于这个问题,如果直接使用opencv里面的旋转函数同样会存在上述问题(但matlab的旋转函数是不存在这个问题的),在opencv中已经有人解决了这个问题了,但在pytorch中,我还没有查阅到解决相关问题的方案,因此本文记录一下我解决这个问题的办法。首先从现象上来看,我们可以很轻松的把问题定义到是坐标映射出现了问题,旋转之后的图像肯定是要比原始图像要大的,因此我们直接在坐标映射那里将目标函数的size根据旋转角度,设置成相应大小,发现问题还是没法解决。因此我们还是需要从源代码入手,直接从pytorch中的Function函数里是找不到答案的,因为我们无法直接定位到其源码处,因此我在github上找了一下affine_grid函数的实现。具体链接为:https://github.com/fxia22/stn.pytorch,在这里我们来贴出部分代码:

class AffineGridGenFunction(Function):
    def __init__(self, height, width,lr=1):
        super(AffineGridGenFunction, self).__init__()
        self.lr = lr
        self.height, self.width = height, width
        self.grid = np.zeros( [self.height, self.width, 3], dtype=np.float32)
        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / self.height), 0), repeats=self.width, axis=0).T,0)
        self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.width), 0), repeats = self.height, axis = 0), 0)
        self.grid[:,:,2] = np.ones([self.height, width])
        self.grid = torch.from_numpy(self.grid.astype(np.float32))

这个是网格映射的初始化,什么意思呢,我们输入一个设定的大小,也就是我们想要的目标图像大小,grip与其等大,grip里面装的是啥呢,第0维装的是以图像中心为坐标,图像高度归一化后的坐标信息,第1维代表宽度归一化后的坐标信息,也就是说grid的每一个元素,代表的都是该处像素归一化后的坐标(为啥使用归一化,因为网络里面的数据都是归一化后的数据),这是第一步初始化。第二步呢,就是将目标图像映射到原始图像中去了,代码如图所示:

output1 = torch.bmm(self.batchgrid.view(-1, self.height*self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 2)

其实就是上面定义的grid与仿射变换矩阵分别做了一个矩阵乘法,那么得到的output里面每个像素点装的东西就是该点对应源图像上面的坐标位置了。再根据这个映射关系,使用采样函数进行插值采样。

out = F.grid_sample(img, grid_out)

这便是pytorch函数里面坐标映射和像素采样的一个大致过程,整个过程看起来很合情合理,与先前的STN旋转变换几何机理都能一一对应。

但是,如果我们细想,就会发现这个实现其实是存在两个大的问题的,首先我们发现,在网格函数grid进行定义的时候,第一步归一化就是有问题的,什么意思呢,虽然他分别对高度和宽度信息都做了归一化,但是其实二者归一化的基准是不同的,但是我们在与仿射变换相乘的过程其实是并没有区分这二者不同的归一化问题的。其次,我们将目标函数归一化了之后与仿射矩阵相乘,得到源图像的坐标信息,这个坐标其实也是归一化的,但这个归一化是相对谁的?其实还是相对于目标图像的不是,但是我们在把却把相对于目标图像进行归一化之后的坐标映射矩阵给到了基于源图像的采样函数里面,这好吗,这不好。


解决方案:

因此我们针对上述分析,在github源码那里做了两部分改进,一个是在网格初始化那里,由于在与仿射矩阵相乘的时候我们没有将宽度和高度二者不同的归一化进行区分,因此我们需要在定义的时候就保证宽度和高度的归一化是相对于同一个基准的。其次,我们需要在像素采样之前,保证送进去的源图像坐标是相对于源图像归一化后的,而不是相对于目标图像进行归一化后的。因此在上述链接中,大家需要改变上述两个部分的代码,第一部分是网格函数初始化定义处,相应位置改成如下(这里是统一以width为基准进行归一化,其实以哪个为基准无所谓,保证是同一个就行):

        self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.height), 0), repeats = self.width, axis = 0).T, 0)*height/width
        self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.width), 0), repeats = self.height, axis = 0), 0)

第二部分,在像素采样函数之前,将网格grid里面的坐标重新进行归一化,(由于该源代码将height放在第0维,所以相应的仿射矩阵也需要做出一定的改变),具体如下:

    theta = torch.tensor([
        [math.sin(-alpha), math.cos(alpha), 0],
        [math.cos(alpha), math.sin(alpha), 0]
    ], dtype=torch.float)
    theta  = theta.unsqueeze(0)
    nW = math.ceil(h*math.fabs(math.sin(alpha))+w*math.cos(alpha))
    nH = math.ceil(h*math.cos(alpha)+w*math.fabs(math.sin(alpha)))

    g = AffineGridGen(nH, nW, aux_loss=True)
    grid_out, aux = g(theta)
    grid_out[:,:,:,0] = grid_out[:,:,:,0]*nW / w
    grid_out[:, :, :, 1] = grid_out[:, :, :, 1] * nW/h
    out = F.grid_sample(img, grid_out)

到这里,pytorch中的完整旋转变换的问题就解决了,很遗憾需要我们自己来定义修改网格定义函数,没有找到能够直接不改源代码的情况下解决上述问题,如果大家有找到不修改源代码的方法,欢迎直接私信或者留言。

 

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值