【pytorch】torch.cat和torch.stack的区别
torch.cat
import torch
a=torch.randn((1,3,4,4)) #假设代表了[N,c,w,h]
b=torch.cat((a,a)) #维度默认是0
# (2, 3, 4, 4)
c=torch.cat((a,a),dim=1)
# (1, 6, 4, 4)
接下来看一些维度不同的
import torch
a=torch.randn((1,3,4,4)
torch.cat
import torch
a=torch.randn((1,3,4,4)) #假设代表了[N,c,w,h]
b=torch.cat((a,a)) #维度默认是0
# (2, 3, 4, 4)
c=torch.cat((a,a),dim=1)
# (1, 6, 4, 4)
接下来看一些维度不同的
import torch
a=torch.randn((1,3,4,4)