torch.meshgrid
是 PyTorch 中的一个函数,用于创建多维网格坐标。
它接受多个一维张量作为输入,并根据指定的索引模式(通过 indexing
参数,默认为 'ij'
)生成相应的多维网格张量。
语法如下:
torch.meshgrid(*tensors, indexing='ij')
- 当
indexing='ij'
(默认)时,第一个输入张量沿着行方向扩展,第二个输入张量沿着列方向扩展,以此类推。 - 当
indexing='xy'
时,第一个输入张量沿着列方向扩展,第二个输入张量沿着行方向扩展,以此类推。
例如:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])
# 默认 indexing='ij'
xx, yy = torch.meshgrid(x, y)
print(xx)
print(yy)
# indexing='xy'
xx_xy, yy_xy = torch.meshgrid(x, y, indexing='xy')
print(xx_xy)
print(yy_xy)
torch.meshgrid
在很多涉及坐标操作、图像处理、构建网格数据等任务中非常有用。