pytorch中F.grid_sample函数实现warp功能记录

一.

光流算法中最重要的假设是亮度一致性,即:

                                                                                                I1(x, y) = I2(x+u, y+v)

x = (x, y), u = (u, v),  w(x,u) = x + u, 则I2(x+u, y+v)的可以记作I2(w(x, y)).

w即是光流算法中的warp函数,在pytorch中可以借助torch.nn.functional.grid_sample实现!

 

对于output中的每一个像素(x, y),它会根据流值在input中找到对应的像素点(x+u, y+v),并赋予自己对应点的像素值,这便完成了warp操作。但这个对应点的坐标不一定是整数值,因此要用到插值或者使用邻近值,也就是选项mode的作用。

那么如何找到对应像素点呢?关键的过程在于grid,若grid(x,y)的两个通道值为( m, n ),则表明output(x,y)的对应点在input的(m, n)处。但这里一般会将m和n的取值范围归一化到[-1, 1]之间,[-1, -1]表示input左上角的像素的坐标,[1, 1]表示input右下角的像素的坐标,对于超出这个范围的坐标,函数将会根据参数padding_mode的设定进行不同的处理。

因此,首先指定一个[-1, 1]的网格G,网格的间距为2/width 和 2/height, 此时将G传入grid, output会在原地寻找对应点.但我们还有光流值(u,v)没有利用,如何利用还需要进一步分析:

(1)之前的索引值为[0, width-1]和[0, height-1],   (x , y)对应的坐标应为 (x + u, y + v)。

(2)现在的索引值为[-1, 1] 和 [-1, 1] ,   (x, y)对应的坐标应该为(x + u_scale, v + v_scale),该坐标值同样位于[-1, 1]之间,根据该值在input中寻找像素点。

由于像素点的间距由1变化到2/(width-1)和2/(height-1),因此也对光流值进行缩放使其变为

                                                                               u_scale = ( u * (2/(width-1)),  v * ( 2/(height-1)) )

然后在将G与u_scale相加,便可以指定output(x,y)在input中的对应点input(x+u, y+v)位置.

 

二. 对I2图像的warp操作在backwarp函数中实现:

输入:

1. tenInput为I2图像,shape = (batchsize, 3,  H, W) , 3个通道分别代表RGB。

2. tenFlow为估计的光流值, shape = (batchsize, 2, H , W),   2个通道分别为u, v分量。

输出:

对I2进行warp操作后所对应的图像I2(w(x, y)),如果光流值完全准确,该图形应该与I1相同。

backwarp_tenGrid = {}
def backwarp(tenInput, tenFlow):
	if str(tenFlow.size()) not in backwarp_tenGrid:
		tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
		tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])

		backwarp_tenGrid[str(tenFlow.size())] = torch.cat([ tenHorizontal, tenVertical ], 1).cuda()
	# end

	tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)

	return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.size())] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=True)
# end

 1.首先,要创建一个[-1, 1]的网格backwarp_tenGrid, 其包含两个通道:

(1).tenHorizontal用于指定横坐标x的位置.

(2).tenVertical用于指定纵坐标y的位置.

 2. 对tenFlow进行缩放,就是前面所介绍的:

                                                                               u_scale = ( u * (2/(width-1)),  v * ( 2/(height-1)) )

3.  grid = (backwarp_tenGrid + u_scale)并通过tensor.permute变换维度.  premute的演示如下:

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值