TensorRT——grid_sample转换

mmcv 中有 grid_sample 的纯 PyTorch 算子实现 bilinear_grid_sample,可用来替换 F.grid_sample,便于转换到 TensorRT 并实现加速:

def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinear grid sampling built only from basic tensor ops.

    Intended as a drop-in replacement for
    ``F.grid_sample(im, grid, mode='bilinear', padding_mode='zeros',
    align_corners=align_corners)`` that uses only pad/floor/where/gather.

    Args:
        im: input feature map of shape (n, c, h, w).
        grid: sampling grid of shape (n, gh, gw, 2); the last axis holds the
            normalized (x, y) coordinates in the ``F.grid_sample`` convention.
        align_corners: same meaning as in ``F.grid_sample``.

    Returns:
        Tensor of shape (n, c, gh, gw).
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Map normalized [-1, 1] coordinates to pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    # reshape rather than view: a grid built with permute/transpose yields
    # slices whose strides cannot be viewed as (n, gh * gw).
    x = x.reshape(n, -1)
    y = y.reshape(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights of the four neighbours, each of shape (n, 1, gh * gw).
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # A one-pixel zero border reproduces grid_sample's padding_mode='zeros'.
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # Shift neighbour indices into the padded image's coordinate frame.
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip indices to the padded image; anything out of range lands on the
    # zero border. The fill values are created on im's device (not "cuda if
    # available"), so CPU inputs also work on a machine that has CUDA.
    # torch.where is kept instead of clamp, presumably for export
    # friendliness of the traced graph.
    zero = torch.tensor(0, device=im.device)
    max_x = torch.tensor(padded_w - 1, device=im.device)
    max_y = torch.tensor(padded_h - 1, device=im.device)
    x0 = torch.where(x0 < 0, zero, x0)
    x0 = torch.where(x0 > padded_w - 1, max_x, x0)
    x1 = torch.where(x1 < 0, zero, x1)
    x1 = torch.where(x1 > padded_w - 1, max_x, x1)
    y0 = torch.where(y0 < 0, zero, y0)
    y0 = torch.where(y0 > padded_h - 1, max_y, y0)
    y1 = torch.where(y1 < 0, zero, y1)
    y1 = torch.where(y1 > padded_h - 1, max_y, y1)

    im_padded = im_padded.reshape(n, c, -1)

    # Flat indices into each padded (h + 2) * (w + 2) plane, shared by all channels.
    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

CRE 中同样使用 bilinear_grid_sample 来替代 F.grid_sample,实现如下:

def bilinear_grid_sample(im, grid, align_corners=False):
    """Emulate bilinear ``torch.nn.functional.grid_sample`` with basic ops.

    Equivalent to ``grid_sample(im, grid, mode='bilinear',
    padding_mode='zeros', align_corners=align_corners)``, expressed with
    pad, floor, where and gather only.

    Args:
        im: input of shape (n, c, h, w).
        grid: sampling grid of shape (n, gh, gw, 2) holding normalized (x, y)
            coordinates in the ``grid_sample`` convention.
        align_corners: same meaning as in ``grid_sample``.

    Returns:
        Tensor of shape (n, c, gh, gw).
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    # reshape, not view: the slices above may have strides that view rejects
    # (e.g. when grid was produced by permute).
    x = x.reshape(n, -1)
    y = y.reshape(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights of the four neighbours, shape (n, 1, gh * gw).
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply default for grid_sample function zero padding
    im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # save points positions after padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to the padded image size; out-of-range samples hit the
    # zero border. Fill values live on im's device to avoid device mismatches.
    x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)

    im_padded = im_padded.reshape(n, c, -1)

    # Flat indices into each padded plane, broadcast across channels.
    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

paddle实现:

def paddle_bilinear_grid_sample(self, im, grid, align_corners=False):
        """Paddle port of the pure-op bilinear grid_sample emulation.

        Samples ``im`` at the normalized locations in ``grid`` with bilinear
        interpolation and zero padding, using pad / floor / where /
        take_along_axis only.

        Args:
            im: input tensor, unpacked below as [n, c, h, w].
            grid: sampling grid, unpacked below as [n, gh, gw, 2]; the last
                axis holds normalized (x, y) coordinates.
            align_corners: selects the normalized-to-pixel mapping, as in
                ``F.grid_sample``.

        Returns:
            Tensor of shape [n, c, gh, gw].
        """
        # Input sizes used by every step below; without these the body
        # raises NameError on n / c / h / w / gh / gw.
        n, c, h, w = im.shape
        gn, gh, gw, _ = grid.shape
        assert n == gn

        x = grid[:, :, :, 0]
        y = grid[:, :, :, 1]

        # Normalized [-1, 1] -> pixel coordinates.
        if align_corners:
            x = ((x + 1) / 2) * (w - 1)
            y = ((y + 1) / 2) * (h - 1)
        else:
            x = ((x + 1) * w - 1) / 2
            y = ((y + 1) * h - 1) / 2

        x = paddle.reshape(x, [n, -1])
        y = paddle.reshape(y, [n, -1])

        # Neighbour coordinates are kept as float32 here and only cast to
        # int64 once the flat indices have been formed.
        x0 = paddle.floor(x).astype('float32')
        y0 = paddle.floor(y).astype('float32')
        x1 = x0 + 1
        y1 = y0 + 1

        # Bilinear weights, computed in grid's dtype.
        x1_cast = x1.astype(grid.dtype)
        x0_cast = x0.astype(grid.dtype)
        y1_cast = y1.astype(grid.dtype)
        y0_cast = y0.astype(grid.dtype)
        wa = paddle.unsqueeze(((x1_cast - x) * (y1_cast - y)), 1)
        wb = paddle.unsqueeze(((x1_cast - x) * (y - y0_cast)), 1)
        wc = paddle.unsqueeze(((x - x0_cast) * (y1_cast - y)), 1)
        wd = paddle.unsqueeze(((x - x0_cast) * (y - y0_cast)), 1)

        # Apply default for grid_sample function zero padding
        im_padded = paddle.nn.functional.pad(im,
                                             pad=[1, 1, 1, 1],
                                             mode='constant',
                                             value=0)
        if im_padded.dtype != im.dtype:
            im_padded = paddle.cast(im_padded, im.dtype)
        padded_h = h + 2
        padded_w = w + 2
        # save points positions after padding
        x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

        # Clip coordinates to padded image size; out-of-range samples hit
        # the zero border.
        tensor_zero = paddle.full(shape=[1], dtype='float32', fill_value=0.0)
        tensor_padded_w = paddle.full(
            shape=[1], dtype='float32', fill_value=padded_w - 1)
        tensor_padded_h = paddle.full(
            shape=[1], dtype='float32', fill_value=padded_h - 1)
        x0 = paddle.where(x0 < 0, tensor_zero, x0)
        x0 = paddle.where(x0 > padded_w - 1, tensor_padded_w, x0)
        x1 = paddle.where(x1 < 0, tensor_zero, x1)
        x1 = paddle.where(x1 > padded_w - 1, tensor_padded_w, x1)
        y0 = paddle.where(y0 < 0, tensor_zero, y0)
        y0 = paddle.where(y0 > padded_h - 1, tensor_padded_h, y0)
        y1 = paddle.where(y1 < 0, tensor_zero, y1)
        y1 = paddle.where(y1 > padded_h - 1, tensor_padded_h, y1)
        im_padded = paddle.reshape(im_padded, [n, c, -1])

        # Flat indices into each padded plane, broadcast across channels.
        # NOTE(review): they are formed in float32 before the int64 cast, so
        # they are exact only while (h + 2) * (w + 2) <= 2**24.
        x0_y0 = paddle.expand(
            paddle.unsqueeze((x0 + y0 * padded_w), 1), [-1, c, -1]).astype('int64')
        x0_y1 = paddle.expand(
            paddle.unsqueeze((x0 + y1 * padded_w), 1), [-1, c, -1]).astype('int64')
        x1_y0 = paddle.expand(
            paddle.unsqueeze((x1 + y0 * padded_w), 1), [-1, c, -1]).astype('int64')
        x1_y1 = paddle.expand(
            paddle.unsqueeze((x1 + y1 * padded_w), 1), [-1, c, -1]).astype('int64')

        # Alternative: self.gather(im_padded, 2, idx), a gather_nd-based helper.
        Ia = paddle.take_along_axis(im_padded, x0_y0, 2).astype('float32')
        Ib = paddle.take_along_axis(im_padded, x0_y1, 2).astype('float32')
        Ic = paddle.take_along_axis(im_padded, x1_y0, 2).astype('float32')
        Id = paddle.take_along_axis(im_padded, x1_y1, 2).astype('float32')

        return paddle.reshape((Ia * wa + Ib * wb + Ic * wc + Id * wd),
                              [n, c, gh, gw])
    def gather(self, x, dim, index):
        """Emulate ``torch.gather(x, dim, index)`` with ``paddle.gather_nd``.

        For each position p of ``index`` the result holds ``x`` at p with the
        coordinate along ``dim`` replaced by ``index[p]``.

        Args:
            x: source tensor.
            dim: axis to gather along; negative values count from the end.
            index: integer tensor of the same rank as ``x``.

        Returns:
            Tensor with the shape of ``index``.
        """
        # Runtime shape tensor (paddle.shape), as used for expand/reshape below.
        index_shape = paddle.shape(index)
        rank = len(x.shape)
        if dim < 0:
            dim = rank + dim
        index_flatten = index.flatten()
        nd_index = []
        for k in range(rank):
            if k == dim:
                nd_index.append(index_flatten)
                continue
            # The coordinate along a non-gather axis is the position inside
            # `index`, so enumerate index's extent. Enumerating x's extent
            # cannot be expanded whenever index is smaller than x along k.
            reshape_shape = [1] * rank
            reshape_shape[k] = -1
            coords = paddle.arange(index_shape[k], dtype=index.dtype)
            coords = coords.reshape(reshape_shape)
            nd_index.append(paddle.expand(coords, index_shape).flatten())
        # One row of full coordinates per output element: [numel, rank].
        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
        return paddle.gather_nd(x, ind2).reshape(index_shape)

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值