torch.stack()函数其实等于是在堆叠数据。
对于一个二维的tensor类型数据来说,距离如下:
import torch
a=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
a
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
这里的*维变量其实就等于是这个tensor里的unit数据的维度。
torch.cat(input,dim=0)
默认按行连接张量
torch.stack()函数其实等于是在堆叠数据。
对于一个二维的tensor类型数据来说,距离如下:
import torch
a=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
a
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
这里的*维变量其实就等于是这个tensor里的unit数据的维度。
torch.cat(input,dim=0)
默认按行连接张量