mmcv 中有对其(即 F.grid_sample)的实现 bilinear_grid_sample,可以实现加速:
def bilinear_grid_sample(im, grid, align_corners=False):
    """Sample ``im`` at ``grid`` locations using bilinear interpolation.

    Locations falling outside the image read zeros (zero padding).

    Args:
        im: Input tensor, unpacked as ``(n, c, h, w)``.
        grid: Sampling grid, unpacked as ``(n, gh, gw, 2)``. The last axis
            holds ``(x, y)`` coordinates normalized to ``[-1, 1]``.
        align_corners: If True, -1 and 1 map to pixel indices ``0`` and
            ``size - 1``; otherwise they map to ``-0.5`` and ``size - 0.5``.

    Returns:
        Tensor reshaped to ``(n, c, gh, gw)``.

    Raises:
        AssertionError: If ``im`` and ``grid`` have different batch sizes.
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Convert normalized coordinates to (fractional) pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    # Integer coordinates of the four pixels surrounding each sample point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights; computed from the unclipped neighbours on purpose.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply default for grid_sample function zero padding
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2

    # save points positions after padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to padded image size. torch.clamp needs no helper
    # tensors, so the result stays on the inputs' device. The previous code
    # created the bounds on "cuda" whenever CUDA was available, which breaks
    # for CPU inputs on a GPU machine.
    x0 = torch.clamp(x0, 0, padded_w - 1)
    x1 = torch.clamp(x1, 0, padded_w - 1)
    y0 = torch.clamp(y0, 0, padded_h - 1)
    y1 = torch.clamp(y1, 0, padded_h - 1)

    # Flatten the spatial dims so each corner becomes a single gather index.
    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
CRE中也是用bilinear_grid_sample替代F.grid_sample的实现:
def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinearly interpolate ``im`` at the normalized points in ``grid``.

    Serves as a stand-in for ``F.grid_sample`` built from basic tensor ops;
    points outside the image read zeros.

    Args:
        im: Input tensor, unpacked as ``(n, c, h, w)``.
        grid: Sampling grid, unpacked as ``(n, gh, gw, 2)``; the last axis
            is ``(x, y)`` in normalized ``[-1, 1]`` coordinates.
        align_corners: Selects which of the two normalized-to-pixel mappings
            below is used.

    Returns:
        Tensor reshaped to ``(n, c, gh, gw)``.
    """
    batch, channels, height, width = im.shape
    grid_batch, out_h, out_w, _ = grid.shape
    assert batch == grid_batch

    # Normalized [-1, 1] coordinates -> fractional pixel coordinates.
    px = grid[..., 0]
    py = grid[..., 1]
    if align_corners:
        px = ((px + 1) / 2) * (width - 1)
        py = ((py + 1) / 2) * (height - 1)
    else:
        px = ((px + 1) * width - 1) / 2
        py = ((py + 1) * height - 1) / 2
    px = px.view(batch, -1)
    py = py.view(batch, -1)

    # Enclosing integer pixel box of every sample point.
    left = torch.floor(px).long()
    top = torch.floor(py).long()
    right = left + 1
    bottom = top + 1

    # Interpolation weights, one per box corner (taken before any clipping).
    w_top_left = ((right - px) * (bottom - py)).unsqueeze(1)
    w_bottom_left = ((right - px) * (py - top)).unsqueeze(1)
    w_top_right = ((px - left) * (bottom - py)).unsqueeze(1)
    w_bottom_right = ((px - left) * (py - top)).unsqueeze(1)

    # A one-pixel zero border provides grid_sample's default zero padding.
    padded = torch.nn.functional.pad(
        im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h, padded_w = height + 2, width + 2

    # Shift indices into the padded frame, then clip them to its bounds.
    left = (left + 1).clamp(0, padded_w - 1)
    right = (right + 1).clamp(0, padded_w - 1)
    top = (top + 1).clamp(0, padded_h - 1)
    bottom = (bottom + 1).clamp(0, padded_h - 1)

    flat = padded.view(batch, channels, -1)

    def _corner_values(col, row):
        # Row-major flat index into the padded image, shared by all channels.
        flat_index = (col + row * padded_w).unsqueeze(1)
        return torch.gather(flat, 2, flat_index.expand(-1, channels, -1))

    blended = (_corner_values(left, top) * w_top_left
               + _corner_values(left, bottom) * w_bottom_left
               + _corner_values(right, top) * w_top_right
               + _corner_values(right, bottom) * w_bottom_right)
    return blended.reshape(batch, channels, out_h, out_w)
paddle实现:
def paddle_bilinear_grid_sample(self, im, grid, align_corners=False):
    """Bilinear grid sampling with zero padding, written with Paddle ops.

    Args:
        im: Input tensor, unpacked as ``(n, c, h, w)``.
        grid: Sampling grid, unpacked as ``(gn, gh, gw, 2)``; index 0 of the
            last axis is x and index 1 is y, in normalized ``[-1, 1]``
            coordinates.
        align_corners: Selects which of the two normalized-to-pixel mappings
            below is used.

    Returns:
        The blended samples reshaped to ``[n, c, gh, gw]``. The gathered
        corner values are cast to float32 before blending.
    """
    # These sizes were previously used without ever being defined, which
    # raised NameError on the first reference.
    # NOTE(review): assumes the shapes are concrete Python ints (dynamic-graph
    # mode); confirm before using with unknown dims in static graph.
    n, c, h, w = im.shape
    _, gh, gw, _ = grid.shape

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Convert normalized coordinates to (fractional) pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = paddle.reshape(x, [n, -1])
    y = paddle.reshape(y, [n, -1])

    # Neighbour coordinates are kept as float32 so they can be clipped with
    # the float32 constants below; they become int64 only when used as
    # gather indices.
    x0 = paddle.floor(x).astype('float32')
    y0 = paddle.floor(y).astype('float32')
    x1 = x0 + 1
    y1 = y0 + 1

    # Weights are computed in the grid's dtype from the unclipped neighbours.
    x1_cast = x1.astype(grid.dtype)
    x0_cast = x0.astype(grid.dtype)
    y1_cast = y1.astype(grid.dtype)
    y0_cast = y0.astype(grid.dtype)

    wa = paddle.unsqueeze(((x1_cast - x) * (y1_cast - y)), 1)
    wb = paddle.unsqueeze(((x1_cast - x) * (y - y0_cast)), 1)
    wc = paddle.unsqueeze(((x - x0_cast) * (y1_cast - y)), 1)
    wd = paddle.unsqueeze(((x - x0_cast) * (y - y0_cast)), 1)

    # Apply default for grid_sample function zero padding
    im_padded = paddle.nn.functional.pad(im,
                                         pad=[1, 1, 1, 1],
                                         mode='constant',
                                         value=0)
    if im_padded.dtype != im.dtype:
        im_padded = paddle.cast(im_padded, im.dtype)
    padded_h = h + 2
    padded_w = w + 2

    # save points positions after padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to padded image size
    tensor_zero = paddle.full(shape=[1], dtype='float32', fill_value=0.0)
    tensor_padded_w = paddle.full(
        shape=[1], dtype='float32', fill_value=padded_w - 1)
    tensor_padded_h = paddle.full(
        shape=[1], dtype='float32', fill_value=padded_h - 1)
    x0 = paddle.where(x0 < 0, tensor_zero, x0)
    x0 = paddle.where(x0 > padded_w - 1, tensor_padded_w, x0)
    x1 = paddle.where(x1 < 0, tensor_zero, x1)
    x1 = paddle.where(x1 > padded_w - 1, tensor_padded_w, x1)
    y0 = paddle.where(y0 < 0, tensor_zero, y0)
    y0 = paddle.where(y0 > padded_h - 1, tensor_padded_h, y0)
    y1 = paddle.where(y1 < 0, tensor_zero, y1)
    y1 = paddle.where(y1 > padded_h - 1, tensor_padded_h, y1)

    # Flatten the spatial dims so each corner becomes a single flat index.
    im_padded = paddle.reshape(im_padded, [n, c, -1])

    x0_y0 = paddle.expand(
        paddle.unsqueeze((x0 + y0 * padded_w), 1), [-1, c, -1]).astype('int64')
    x0_y1 = paddle.expand(
        paddle.unsqueeze((x0 + y1 * padded_w), 1), [-1, c, -1]).astype('int64')
    x1_y0 = paddle.expand(
        paddle.unsqueeze((x1 + y0 * padded_w), 1), [-1, c, -1]).astype('int64')
    x1_y1 = paddle.expand(
        paddle.unsqueeze((x1 + y1 * padded_w), 1), [-1, c, -1]).astype('int64')

    # Gather the four corner values along the flattened spatial axis.
    Ia = paddle.take_along_axis(im_padded, x0_y0, 2).astype('float32')
    Ib = paddle.take_along_axis(im_padded, x0_y1, 2).astype('float32')
    Ic = paddle.take_along_axis(im_padded, x1_y0, 2).astype('float32')
    Id = paddle.take_along_axis(im_padded, x1_y1, 2).astype('float32')

    return paddle.reshape((Ia * wa + Ib * wb + Ic * wc + Id * wd),
                          [n, c, gh, gw])
def gather(self, x, dim, index):
    """Pick values from ``x`` along axis ``dim`` using ``paddle.gather_nd``.

    For every position in ``index`` a full N-d coordinate is built: the
    ``dim`` component comes from ``index`` itself, every other component is
    that position's own coordinate along the corresponding axis.

    Args:
        x: Source tensor.
        dim: Axis to gather along; negative values count from the last axis.
        index: Index tensor. It must have the same rank as ``x``, and the
            per-axis ranges are expanded to its shape, so its non-``dim``
            sizes must be compatible with those of ``x``.

    Returns:
        The gathered values, reshaped to the shape of ``index``.
    """
    rank = len(x.shape)
    axis = dim + rank if dim < 0 else dim

    # Runtime shape tensors (rather than the static .shape lists).
    out_shape = paddle.shape(index)
    src_shape = paddle.shape(x)
    flat_index = index.flatten()

    def _axis_coordinates(k):
        # 0..size-1 laid out along axis k, broadcast to index's shape.
        extent = src_shape[k]
        view_shape = [1] * rank
        view_shape[k] = extent
        coords = paddle.arange(extent, dtype=index.dtype).reshape(view_shape)
        return paddle.expand(coords, out_shape).flatten()

    per_axis = [
        flat_index if k == axis else _axis_coordinates(k)
        for k in range(rank)
    ]

    # Stack to (rank, num_points) and transpose into gather_nd's
    # (num_points, rank) coordinate layout.
    nd_coords = paddle.transpose(paddle.stack(per_axis), [1, 0]).astype("int64")
    return paddle.gather_nd(x, nd_coords).reshape(out_shape)