一文彻底弄懂 PyTorch 的 `F.grid_sample`

在深度学习和计算机视觉任务中,经常需要对图像或特征图进行采样和变换。PyTorch 提供的 F.grid_sample 函数非常方便,用于根据指定的坐标从输入张量中采样特定点的值。本文将详细介绍如何使用 F.grid_sample,并通过两个具体例子解释其工作原理。

什么是 F.grid_sample

F.grid_sample 是 PyTorch 中的一个函数,用于根据给定的坐标网格对输入张量进行采样。它常用于图像变形、数据增强等任务。函数的核心思想是使用双线性插值从输入张量中提取指定坐标的值。

示例代码

以下代码展示了如何使用 F.grid_sample 从一个 4x4 的输入张量中采样特定点的值:

import torch
import torch.nn.functional as F

# 定义一个 4x4 的输入张量
input_tensor = torch.tensor([[[[1, 2, 3, 4],
                               [5, 6, 7, 8],
                               [9, 10, 11, 12],
                               [13, 14, 15, 16]]]], dtype=torch.float)

# 定义采样点,归一化坐标在 [-1, 1] 范围内
# 这里使用小数坐标进行采样
grid = torch.tensor([[[[-0.5, -0.5], [0.5, -0.5]],
                      [[-0.5, 0.5], [0.5, 0.5]]]], dtype=torch.float)

# 使用 F.grid_sample 进行采样
output = F.grid_sample(input_tensor, grid, align_corners=True)

print(output)

计算过程

假设输入张量的尺寸为 (4, 4),采样点坐标的归一化范围在 [-1, 1],我们将其转换为张量坐标的范围 [0, 3]

归一化坐标转换公式

归一化坐标转换公式如下:
x input = ( x grid + 1 ) ⋅ ( W − 1 ) 2 x_{\text{input}} = \frac{(x_{\text{grid}} + 1) \cdot (W - 1)}{2} xinput=2(xgrid+1)(W1)
y input = ( y grid + 1 ) ⋅ ( H − 1 ) 2 y_{\text{input}} = \frac{(y_{\text{grid}} + 1) \cdot (H - 1)}{2} yinput=2(ygrid+1)(H1)

示例计算 1:归一化采样点 [-0.5, -0.5]

对于归一化采样点 [-0.5, -0.5],我们将其转换为输入张量的实际坐标:

x input = ( − 0.5 + 1 ) ⋅ ( 4 − 1 ) 2 = 0.5 ⋅ 3 2 = 0.75 x_{\text{input}} = \frac{(-0.5 + 1) \cdot (4 - 1)}{2} = \frac{0.5 \cdot 3}{2} = 0.75 xinput=2(0.5+1)(41)=20.53=0.75
y input = ( − 0.5 + 1 ) ⋅ ( 4 − 1 ) 2 = 0.5 ⋅ 3 2 = 0.75 y_{\text{input}} = \frac{(-0.5 + 1) \cdot (4 - 1)}{2} = \frac{0.5 \cdot 3}{2} = 0.75 yinput=2(0.5+1)(41)=20.53=0.75

这样,归一化坐标 [-0.5, -0.5] 对应的输入张量实际坐标为 [0.75, 0.75]

假设采样点 (x, y) 对应输入张量的坐标 [0.75, 0.75],我们可以确定其周围的四个像素值:

  • 左上角像素 (0, 0)
  • 右上角像素 (0, 1)
  • 左下角像素 (1, 0)
  • 右下角像素 (1, 1)

计算权重 wxwy

  • wx = 0.75(x 坐标的小数部分)
  • wy = 0.75(y 坐标的小数部分)

使用双线性插值公式计算插值值:

top_left = input_tensor[0, 0, 0, 0]  # 1
top_right = input_tensor[0, 0, 0, 1]  # 2
bottom_left = input_tensor[0, 0, 1, 0]  # 5
bottom_right = input_tensor[0, 0, 1, 1]  # 6

value = (1 - 0.75) * (1 - 0.75) * 1 + 0.75 * (1 - 0.75) * 2 + (1 - 0.75) * 0.75 * 5 + 0.75 * 0.75 * 6
# 结果为 0.0625 + 0.375 + 0.9375 + 3.375 = 4.75

示例计算 2:归一化采样点 [0.5, -0.5]

对于归一化采样点 [0.5, -0.5],我们将其转换为输入张量的实际坐标:

x input = ( 0.5 + 1 ) ⋅ ( 4 − 1 ) 2 = 1.5 ⋅ 3 2 = 2.25 x_{\text{input}} = \frac{(0.5 + 1) \cdot (4 - 1)}{2} = \frac{1.5 \cdot 3}{2} = 2.25 xinput=2(0.5+1)(41)=21.53=2.25
y input = ( − 0.5 + 1 ) ⋅ ( 4 − 1 ) 2 = 0.5 ⋅ 3 2 = 0.75 y_{\text{input}} = \frac{(-0.5 + 1) \cdot (4 - 1)}{2} = \frac{0.5 \cdot 3}{2} = 0.75 yinput=2(0.5+1)(41)=20.53=0.75

这样,归一化坐标 [0.5, -0.5] 对应的输入张量实际坐标为 [2.25, 0.75]

假设采样点 (x, y) 对应输入张量的坐标 [2.25, 0.75],我们可以确定其周围的四个像素值:

  • 左上角像素 (2, 0)
  • 右上角像素 (2, 1)
  • 左下角像素 (3, 0)
  • 右下角像素 (3, 1)

计算权重 wxwy

  • wx = 0.25(x 坐标的小数部分)
  • wy = 0.75(y 坐标的小数部分)

使用双线性插值公式计算插值值:

top_left = input_tensor[0, 0, 2, 0]  # 9
top_right = input_tensor[0, 0, 2, 1]  # 10
bottom_left = input_tensor[0, 0, 3, 0]  # 13
bottom_right = input_tensor[0, 0, 3, 1]  # 14

value = (1 - 0.25) * (1 - 0.75) * 9 + 0.25 * (1 - 0.75) * 10 + (1 - 0.25) * 0.75 * 13 + 0.25 * 0.75 * 14
# 结果为 0.75 * 0.25 * 9 + 0.25 * 0.25 * 10 + 0.75 * 0.75 * 13 + 0.25 * 0.75 * 14
# 结果为 1.6875 + 0.625 + 7.3125 + 2.625 = 12.25

总结

通过上面的示例计算,我们可以看到如何将归一化坐标转换为输入张量的实际坐标,以及如何使用双线性插值计算采样点的值。具体示例展示了从输入张量 [0.5, -0.5][-0.5, -0.5] 处采样的详细过程。

F.grid_sample 通过双线性插值在输入张量中进行采样,即使坐标是小数,也能准确地计算出相应位置的值。这对于需要高精度变换和采样的任务非常有用。通过这种方法,我们可以实现高精度的图像处理和分析。

参考:
1.https://blog.csdn.net/qq_34914551/article/details/107559031
2.【通俗易懂】详解torch.nn.functional.grid_sample函数:可实现对特征图的水平/垂直翻转

### 使用TPS进行深度学习网络校正 #### TPS原理概述及其应用背景 薄板样条插值(TPS)是一种广泛应用于计算机视觉中的变形模型,能够有效地处理非线性形变。该方法通过最小化能量函数来找到最优的映射关系,在保持全局形状的同时允许局部细节调整[^1]。 对于深度学习而言,TPS被引入至空间变换网络(Spatial Transformer Networks, STN),作为其中的关键组件之一参与图像预处理工作。具体来说,《Robust Scene Text Recognition with Automatic Rectification》一文中提到的空间变换模块即采用了基于TPS的空间几何矫正机制,旨在解决场景文字识别任务中存在的视角扭曲等问题[^2]。 #### 实现过程 为了在神经网络框架内集成TPS功能,开发者们通常会借助于流行的机器学习库如PyTorch完成相应操作。下面给出了一种可能的方式: - **定义控制点集**:选取源图片和平面上若干对应的特征点作为控制点; - **构建参数矩阵W**:依据上述选定的控制点对建立方程组求解未知数w_i; - **计算径向基函数Φ(r)**:根据欧氏距离r确定每一对控制点间的相互影响程度; - **预测目标位置**:利用得到的权重系数和径向基函数表达式估计新坐标的精确值。 以下是采用Python编写的一个简单例子,展示了如何利用PyTorch实现这一流程: ```python import torch from scipy.spatial import distance_matrix def tps_transform(points_src, points_dst, grid_size=20): """Compute the TPS transformation matrix.""" def compute_U(distsq): return distsq * torch.log(distsq + 1e-7) num_points = len(points_src) P = torch.ones((num_points, 3)) P[:, 1:] = points_src K = compute_U(distance_matrix(points_src.cpu().numpy(), points_src.cpu().numpy())) L = torch.zeros(num_points + 3, num_points + 3).to(K.device) L[:num_points, :num_points] = K + torch.eye(num_points)*1e-3 L[:num_points, -3:] = P L[-3:, :num_points] = P.t() V = torch.cat([points_dst, torch.zeros(3, 2)], dim=0) weights = torch.solve(V.unsqueeze(-1), L)[0][:num_points] def transform_grid(grid_coords): U_dist = compute_U(((grid_coords.view(-1, 1, 2)-points_src)**2).sum(dim=-1)).view(*grid_coords.shape[:-1], -1) affine_part = (torch.matmul(P[:,:2].t(),weights[num_points:].unsqueeze(-1))).squeeze() nonrigid_part = ((U_dist*weights[:num_points]).sum(dim=-1)+affine_part).reshape_as(grid_coords) return nonrigid_part x_range = torch.linspace(min(points_src[:, 0]), max(points_src[:, 0]), steps=grid_size) y_range = torch.linspace(min(points_src[:, 1]), max(points_src[:, 1]), steps=grid_size) xv, yv = torch.meshgrid(x_range, y_range) coords = torch.stack([xv.flatten(), yv.flatten()],dim=-1).float().cuda() if next(model.parameters()).is_cuda else .float() transformed_coords = transform_grid(coords) warped_image = F.grid_sample(image_tensor, transformed_coords.reshape(y_range.size()[0], x_range.size()[0], 2)) return warped_image ``` 此段代码实现了基本的TPS转换逻辑,并将其应用于给定网格上的像素重定位。需要注意的是实际应用场景下还需考虑更多因素比如边界条件处理等。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值