无论dim是个tuple还是一个int型整数,都是要让对应的维度消失。
举例:
a = torch.tensor([
[
[
[1,2,3],
[4,5,6]
]
]
])
print(a)
a.shape
输出:
tensor([[[[1, 2, 3],
[4, 5, 6]]]])
torch.Size([1, 1, 2, 3])
对dim=2求和,就让dim=2消失,dim=2对应的值是2,
它消失后的shape就是[1,1,3],怎么样让dim=2消失,1+4,2+5,3+6,这样原来的[1,2,3],[4,5,6]变成了[5,7,9],原来的dim=2消失了,它的位置被原来的dim=3接替,新的shape就成了[1,1,3],如下所示:
b = torch.sum(a, dim=2)
print(b)
print(b.shape)
输出:
tensor([[[5, 7, 9]]])
torch.Size([1, 1, 3])
同样可以对dim=3求和:
c = torch.sum(a, dim=3)
print(c)
print(c.shape)
输出:
tensor([[[ 6, 15]]])
torch.Size([1, 1, 2])
如果dim是个tuple, 那么就让tuple里所有的数对应的维度消失,比如dim=(2,3),那么2,3维度没了,也就是原来a的shape由[1,1,2,3]变成[1,1],如下所示:
d = torch.sum(a, dim=(2,3))
print(d)
print(d.shape)
输出:
tensor([[21]])
torch.Size([1, 1])
21是同时对第2,3维度求和的结果,即1+2+3+4+5+6=21。