x_linspace = torch.linspace(0.5, width - 0.5, width).view(1, width).expand(height, width)
y_linspace = torch.linspace(0.5, height - 0.5, height).view(height, 1).expand(height, width)
``print(x_linspace)
tensor([[0.5000, 1.5000, 2.5000, 3.5000, 4.5000],
[0.5000, 1.5000, 2.5000, 3.5000, 4.5000],
[0.5000, 1.5000, 2.5000, 3.5000, 4.5000],
[0.5000, 1.5000, 2.5000, 3.5000, 4.5000],
[0.5000, 1.5000, 2.5000, 3.5000, 4.5000],
[0.5000, 1.5000, 2.5000, 3.5000, 4.5000]])
print(y_linspace)
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[1.5000, 1.5000, 1.5000, 1.5000, 1.5000],
[2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
[3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
[4.5000, 4.5000, 4.5000, 4.5000, 4.5000],
[5.5000, 5.5000, 5.5000, 5.5000, 5.5000]])
torch.expand(m,n)
假如原张量为(m,1),扩充之后为(m,n)第1个维度复制n遍
假如原张量为(1,n),扩充之后为(m,n) 第0个维度复制m遍
torch.expand()
最新推荐文章于 2024-05-27 17:16:40 发布