torch.cat 和 torch.stack的区别
a = torch.Tensor([[[[1, 2, 3],[1, 2, 3],[1, 2, 3]],
[[4, 5, 6],[4, 5, 6],[4, 5, 6]]],
[[[7, 8, 9],[7, 8, 9],[7, 8, 9]],
[[10, 11, 12],[10, 11, 12],[10, 11, 12]]]])
print(a.size())
b = torch.stack((a, a))
c = torch.cat((a, a))
print(b.size())
print(c.size())
print(a)
print(b)
print(c)
#torch.Size([2, 2, 3, 3])
#stack
#torch.Size([2, 2, 2, 3, 3])#cat
#torch.Size([4, 2, 3, 3])
tensor([[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]])
tensor([[[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]],
[[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]]])
tensor([[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]],
[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]])
如果是
a = torch.Tensor([[[[1, 2, 3],[1, 2, 3],[1, 2, 3]],
[[4, 5, 6],[4, 5, 6],[4, 5, 6]]],
[[[7, 8, 9],[7, 8, 9],[7, 8, 9]],
[[10, 11, 12],[10, 11, 12],[10, 11, 12]]]])
print(a.size())
b = torch.stack((a, a), dim=1)
c = torch.cat((a, a), dim=1)
print(b.size())
print(c.size())
print(a)
print(b)
print(c)
torch.Size([2, 2, 3, 3])
torch.Size([2, 2, 2, 3, 3])
torch.Size([2, 4, 3, 3])
tensor([[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]])
tensor([[[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]]],
[[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]]])
tensor([[[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]],
[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]]],
[[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]],
[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]]]])
x.view(x.size(0), -1)的解释
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.output(x)
return x
torch.Size([4, 2048, 1, 1]) torch.Size([4, 2048]) torch.Size([4, 1000])#1000为分类数目
w = self.fc2(z)
print(w.size())
batch = w.size(0)
w = w.view(batch, self.num_branches, self.out_channels)
print(w.size())
w = self.softmax(w)
print(w.size())
w = w.unsqueeze(-1).unsqueeze(-1)
print(w.size())
torch.Size([4, 256, 1, 1])
torch.Size([4, 2, 128])
torch.Size([4, 2, 128])
torch.Size([4, 2, 128, 1, 1])