在进行cv相关实验中我们用的比较多的都是torch.cat()和torch.stack()函数。其中cat()函数的功能是在当前维度进行数据的拼接。stack()函数首先将当前维度及其以后维度的数据向后移动一位,将该位置的大小修改为1然后在进行拼接。这两个函数的区别是cat()函数不会进行维度扩展,stack()函数会进行维度扩展。以下从代码角度去理解这两个函数。
import numpy as np
import torch
# 创建两个张量并将其cat到一起。
def create_cat():
t1 = torch.zeros((3, 3))
t2 = torch.ones((3, 3))
t_cat = torch.cat([t1, t2], dim=1)
print(t_cat,t_cat.shape)
# 创建两个张量并将其stack到一起。
def create_stack():
t1 = torch.zeros((3,3))
t2 = torch.ones((3,3))
t_stack = torch.stack([t1,t2],dim=1)
print(t_stack,t_stack.shape)
if __name__ == '__main__':
create_cat()
create_stack()
冲上述图可以看到cat()和stack()都是在1维进行拼接的,cat()函数是直接在1维拼接原数据由两个[3,3]大小的数据变成了[3,6]的。stack()函数将数据[3,3]扩充成两个[3,1,3]然后在1维进行拼接变成了[3,2,3].
以下是画图表示,大家凑合着看哈。