1.torch.unqueeze
torch.unsqueeze(input, dim=None) → Tensor
torch.unsqueeze :给input在原dim出添加一个长为1的轴newdim,示例如下:
由此可见,torch.unsequeeze作用只是添加一个长度为1的新轴
- dim = 0,x.shape: (2,3) -> (1,2,3)
- dim = 1,x.shape: (2,3) -> (2,1,3)
- dim = 2,x.shape: (2,3) -> (2,3,1)
- 概括为文字描述:在给dim=i之后的轴的数据套上一个方括号[],如dim=1,将x[0,:],x[1,:]分别用括号括起来,即使x在原0和1轴之间添加一个新轴。
- 另外逆向操作是torch.squeeze,可以去掉指定的一个长度为1的轴,即去掉一个括号
2.torch.cat
torch.cat(tensors, dim=0, *, out=None) → Tensor
tensors可以使列表或者元组, 同时这些tensor除了指定的dim外,其余各轴长度必须对应相同
torch.cat:用来拼接多个tensor,拼接时会按照指定的dim来进行,示例如下:
由此可见,
- dim的取值范围是[-x.shape,x.shape-1],负数就是倒数的意思(python特色下标)
- dim = 0,res.shape = (4,3),拼接dim=0轴上的元素,dim=0上的元素都是一个长度为3的向量,将次向量看成整体即为dim=0轴上的元素,最后x可视为[A,B]向量,x+1即为[A+1,B+1],将这两个拼接就得到了[A,B,A+1,B+1],展开就是最终shape=(4,3)的结果
- dim = 1,res.shape = (2,6), 拼接dim=1轴上的元素,dim=1上的元素都是一个个标量,将x,x+1对应dim=1轴上元素进行拼接,如x的0轴上的第一个元素是[0,1,2], 在dim=1轴上展开分别是0,1,2三个标量,将这三个标量与x+1对应0轴第一个元素在1轴展开的三个元素1,2,3进行拼接,继续如此拼接x,x+1的0轴的第二个元素,即得到了最终结果
- 总的来说torch.cat与各种语言中concat函数并无区别,只是维度可能更高了,被拼接的元素不再限于单个数字或字符这种标量,可能是一个多维张量,便于理解可以将该维度上的非标量元素看成整体,当成标量来拼接。
- 有两个逆向操作 torch.split() and torch.chunk()
3.torch.stack
torch.stack(tensors, dim=0, *, out=None) → Tensor
tensors:与torch.cat一样,可以使张量列表或元组
torch.stack:与torch.cat类似也是拼接,但不同的是:
- tensors中的张量必须shape都相同
- dim取值范围[-x.shape-1,x.shape]
- torch.stack会在原shape中的dim位置新增一个轴,并按dim轴来cat所有张量 。便于理解,可以认为torch.stack会先将所有张量做torch.unsqueeze(tensors,dim)的操作来新增轴,再将得到的张量列表做torch.cat(tensors_unsqueezed,dim)
- 示例如下
这里演示一个难理解点的dim=2的情况:
- 先通过torch.unsqueeze(x,dim=2), torch.unsqueeze(x+1,dim=2)来增加一个dim=2轴来得到两个新张量,记为y和z
2.通过torch.cat((y,z),dim=2)来拼接,最终得到torch.stack((x,x+1),dim=2)的结果
总的来说,torch.stack就是在一个新轴上cat多个张量,建议先理解unsqueeze和cat,这样才能彻底明白stack。
网上一搜torch.stack和torch.cat中文内容全是抄来抄去,明明都没写清楚 ,只简单说明shape的变化(这谁会理解不了),一堆人照抄假装理解了,可见中文环境是真滴差。这篇文章我参考了pytorch官方文档,也看了油管上的视频讲解,最终结合实践,终于搞懂了stack和cat的过程。
文章编写仓促,可能存在细微疏漏望多多包涵。
pytorch官方文档
油管视频
Stack vs Concat in PyTorch, TensorFlow & NumPy - Deep Learning Tensor Ops