用法:
torch.stack((tensor1,tensor2),dim=?)
torch.cat((tensor1,tensor2),dim=?)
dim默认为0
首先要知道dim代表什么意思:
一般情况下,dim最多包括batch_size,channel,height,width这四项
对应下标0,1,2,3
torch.stack和torch.cat都是用于拼接的,核心不同在于使用stack后,原来的张量会增加一维,比如把两个3 * 3(二维)的tensor用torch.stack在dim0拼接,拼接后的tensor的形状是2 * 3 * 3,也就是说每个3 * 3独自成了一个channel;
如果用torch.cat来拼接这两个3 * 3的tensor的话,dim=0就是按行拼接,就相当于把其中一个tensor放到了另一个的下面,最后的大小是6*3,二维。
测试结果如下: