torch.sum
import torch
import numpy as np
# Demo: how the `dim` argument of torch.sum selects the axis that is
# reduced, and how keepdim=True keeps that axis as a size-1 dimension.
a = torch.ones((2, 2))  # 2x2 tensor of ones; only printed for reference
b = np.array([[1, 2, 3], [1, 1, 1]])  # 2x3 integer NumPy array
c = torch.from_numpy(b)  # the same values as a torch tensor

# dim=0 collapses the rows    -> one sum per column, shape (1, 3)
# dim=1 collapses the columns -> one sum per row,    shape (2, 1)
interval_0, interval_1 = (
    torch.sum(c, dim=axis, keepdim=True) for axis in (0, 1)
)

# Each entry prints its label on one line, followed by the value itself.
for label, value in (
    ("a(2x2):", a),
    ("b:", b),
    ("c:", c),
    ("sum0:", interval_0),
    ("sum1:", interval_1),
):
    print(label, value, sep="\n")
输出(重点关注sum0和sum1):
a(2x2):
tensor([[1., 1.],
        [1., 1.]])
b:
[[1 2 3]
 [1 1 1]]
c:
tensor([[1, 2, 3],
        [1, 1, 1]])
sum0:
tensor([[2, 3, 4]])
sum1:
tensor([[6],
        [3]])
参考:https://blog.csdn.net/qq_39463274/article/details/105145029
https://blog.csdn.net/SakuraHimi/article/details/104466849