pytorch中的F.grid_sample使用方法及应用代码(align_corners参数详细解释)

首先Pytorch中grid_sample函数的接口声明如下:

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

简单来说就是,提供一个input的Tensor以及一个对应的grid网格,然后根据grid中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。

其中,input、grid和output分别表示输入Tensor,映射网格Tensor和输出Tensor,其尺寸如下所示:
i n p u t : ( N , C , H , W ) g r i d : ( N , H , W , 2 ) o u t p u t : ( N , C , H , W ) input: (N,C,H,W)\\ grid:(N,H,W,2)\\ output:(N,C,H,W) input:(N,C,H,W)grid:(N,H,W,2)output:(N,C,H,W)
其中grid中的2表示x,y方向的两个通道的映射,H,W是尺寸。这里一般会将(x,y)归一化到[-1,1],其中[-1,-1]代表左上角的像素,[1,1]代表右下角的像素。首先,以下代码先生成一个恒等的采样矩阵,然后进行采样。这里的恒等采样矩阵就是指左上角是[-1,-1],右上角是[-1,1],左下角是[1,-1],右下角是[1,1],并且元素均匀分布的矩阵,该矩阵是我们采样的原始参考grid网格。

import torch.nn.functional as F
import torch
featSize = 5
#生成恒等网络采样grid
gridY = torch.linspace(-1, 1, steps = featSize).view(1, -1, 1, 1).expand(1, featSize, featSize, 1)
gridX = torch.linspace(-1, 1, steps = featSize).view(1, 1, -1, 1).expand(1, featSize,  featSize, 1)
grid = torch.cat((gridX, gridY), dim=3).type(torch.float32)

#生成输入tensor(input),大小为[1,1,featSize,featSize],数据分布[1,2,...,featSize*featSize]
predict_roof = torch.linspace(1,featSize**2,steps=featSize**2,dtype=torch.float32).reshape(featSize,featSize)
predict_roof = (predict_roof.unsqueeze(dim=0)).unsqueeze(dim=0)

#使用恒等采样矩阵grid对input进行采样
trans_feature = F.grid_sample(predict_roof,grid,align_corners=True)
print(predict_roof)
print(grid.permute(0,3,1,2))
print(trans_feature)

输出结果为:可见恒等采样矩阵的采样结果是不改变原始的数据分布。
恒等采样矩阵采样结果
值得注意的是,F.grid_sample()中存在一个bool参数(align_corners)

  • 当该参数设置为True时:采样网格的坐标被视为指向像素的角点。这种设置通常用于对齐图像边缘的像素。
  • 当该参数设置为False时:采样网格的坐标被视为指向像素之间的中心点。这种设置通常用于保持图像的整体形状和分辨率。

以下对上述两个设置进行详细的说明:
首先,我们如何看待一副图像中像素的组成。这里又两种方式:1、看作是方块;2、看作是点。如下图所示。对于一个 5 × 5 5\times 5 5×5的一副图像。
像素分布示意图

  • align_corners = False:
    -原始像素可以看作是一个点,如上述右图所示。映射网格grid的数据点被视为指向像素之间的中心点。计算过程示意图如下所示。需要注意的是,原始图像大小为 5 × 5 5\times 5 5×5,输出Tensor大小为 6 × 6 6\times6 6×6,并且F.grid_sample()的padding方式是zeros填充方式。
    align_corners = False映射结果

上述输出tensor左上角的像素对应(-1,-1),右下角像素对应(1,1)。代码计算结果如下所示。
代码输出结果

  • align_corners = True:
    原始像素可以看作是一个像素方框,如上述左图所示。映射网格grid的数据点被视为指向像素之间的角点。计算过程如下所示。输入Tensor大小为 5 × 5 5\times 5 5×5,输出大小为 5 × 5 5\times 5 5×5
    align_corners = True计算示意图

左上角对应的grid坐标为(-1,-1),右下角对应的grid坐标为(1,1)。上述grid为恒等采样矩阵时,输出Tensor和输入Tensor是一样的。为了说明上述的运算过程,现在假设某一个像素点的grid映射坐标为(-0.9,-1),即上述示意图中的紫色点,该点距离左点0.1,距右点0.4。即有如下计算:
0.1 0.4 + 0.1 × 2 + 0.4 0.4 + 0.1 × 1 = 1.2 \frac{0.1}{0.4+0.1}\times2+\frac{0.4}{0.4+0.1}\times1 = 1.2 0.4+0.10.1×2+0.4+0.10.4×1=1.2
代码中设置映射grid[0,0,0,:] = [-0.9,-1],即输出左上角映射到[-0.9,1],经过以上计算,左上角点应该是1.2。结果如下所示。
代码运行示意图
总结:本质上来说,align_corners 设置映射的网格grid是否与输入Tensor对齐,False代表不对齐,True代表对齐。以后提供应用代码,今天有点累,想摆烂了。未完待续…

### torchvision.transforms.functional.grid_sample使用方法 `F.grid_sample` 是 PyTorch 中的一个函数,主要用于执行基于网格的采样操作。该功能对于图像变形、空间变换网络 (STN) 及其他涉及几何变换的任务非常有用[^1]。 #### 函数签名 ```python torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None) ``` - `input`: 输入张量形状为 `(N, C, H_in, W_in)`,表示批量大小 N、通道数 C 和输入的高度宽度。 - `grid`: 形状为 `(N, H_out, W_out, 2)` 的张量,指定新位置上的像素坐标。 - `mode`: 插值模式,默认为 `'bilinear'`;也可以设置为 `'nearest'` 或 `'bicubic'`。 - `padding_mode`: 边界填充方式,默认为 `'zeros'`;还可以选择 `'border'` 或 `'reflection'`。 - `align_corners`: 如果设为 True,则保留角点对齐关系。 #### 应用场景 在图像生成领域,`grid_sample` 常被用来实现各种类型的图像转换效果: - **仿射变换**:可以创建旋转、缩放和平移的效果; - **透视变换**:模拟相机视角的变化; - **弹性形变**:模仿生物组织或其他软物质的行为特性; - **风格迁移**:调整源图片以匹配目标艺术作品的外观特征。 下面是一个简单的例子来展示如何应用此函数进行基本的空间变换: ```python import torch from torchvision import transforms import matplotlib.pyplot as plt def show_images(images, cols=1, titles=None): """显示多幅图像""" n_images = len(images) fig = plt.figure() for n, image in enumerate(images): a = fig.add_subplot(cols, np.ceil(n_images / float(cols)), n + 1) if image.ndim == 2: plt.gray() plt.imshow(image) if titles is not None: a.set_title(titles[n]) plt.axis('off') plt.show() # 创建随机噪声作为原始图像 img_tensor = torch.randn(1, 3, 256, 256) # 定义一个简单的线性变换矩阵(例如90度逆时针旋转) theta = torch.tensor([[[0., -1.], [1., 0.]]]) # 将 theta 转换成适合于 affine_grid 的形式 affine_matrix = theta.repeat(img_tensor.size(0), 1, 1).float() # 计算新的采样网格 grid = F.affine_grid(affine_matrix, img_tensor.size()) # 对原图施加变换 transformed_img = F.grid_sample(img_tensor, grid) show_images([ transforms.ToPILImage()(img_tensor.squeeze()), transforms.ToPILImage()(transformed_img.squeeze()) ], titles=['Original Image', 'Transformed Image']) ``` 上述代码片段展示了如何利用 `grid_sample` 来完成一次简单的旋转变换,并通过可视化工具查看变化前后的对比情况。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值