torch.meshgrid()函数解析
torch.meshgrid()的功能是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数, 其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素,各列元素相同。
注意: 当两个输入张量数据类型不同或维度不是一维时会报错。
看文字不懂没关系,直接上代码,理解了代码之后返回来看文字解释就会很清晰。
import torch
a = torch.tensor([1, 2, 3, 4])
print(a)
b = torch.tensor([4, 5, 6])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
结果显示:
tensor([1, 2, 3, 4])
tensor([4, 5, 6])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]])
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
torch.stack() 函数解析
假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
在pytorch
中,常见的拼接函数主要是两个,分别是:
stack()
cat()
这里解释不常见的stack()
参数
-
inputs : 待连接的张量序列。
注:python
的序列数据只有list
和tuple
。 -
dim : 新的维度, 必须在
0
到len(outputs)
之间。
注:len(outputs)
是生成数据的维度大小,也就是outputs
的维度值
结合下面代码理解:
import torch
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1,T2),dim=0))
print(torch.stack((T1,T2),dim=1))
print(torch.stack((T1,T2),dim=2))
print(torch.stack((T1,T2),dim=3))
#寄过如下
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
注意结合代码理解不同维度下的拼接方法