记录一些常见的函数(方法)的维度问题,方便日后查找。
1.torch.sum()
import torch
a = torch.randn(2, 3, 4)
b = torch.sum(a, dim=0)
print("a=", a)
print("a_shape=", a.shape)
print("b=", b)
print("b_shape=", b.shape)
结果如下:
a= tensor([[[-1.3656, -0.2668, -0.6884, -0.6200],
[ 1.8207, -0.2331, -0.5745, 1.0606],
[-1.1094, -2.0471, -0.4491, 0.0909]],[[ 0.0968, 0.3171, 0.0322, -0.8949],
[ 0.1554, 0.1303, 0.2219, 1.4291],
[-0.4944, -0.9078, 1.2116, 0.5914]]])
a_shape= torch.Size([2, 3, 4])
b= tensor([[-1.2688, 0.0503, -0.6563, -1.5149],
[ 1.9761, -0.1028, -0.3525, 2.4897],
[-1.6037, -2.9550, 0.7625, 0.6823]])
b_shape= torch.Size([3, 4])
可以看到指定求和维度dim=0时,求和后张量的维度与未指定的两个维度保持一致,相当于
a[0, :, :]与a[1, :, :]的对应元素求和。
2.torch.max()
import torch
a = torch.randn(2, 3, 4)
b = torch.max(a, dim=0)[0]
c = torch.max(a, dim=0)[1]
print("a=", a)
print("a_shape=", a.shape)
print("b=", b)
print("b_shape=", b.shape)
print("c=", c)
print("c_shape=", c.shape)
结果如下:
a= tensor([[[-9.4846e-01, 7.8566e-01, 1.6108e-01, 1.1380e+00],
[-1.3625e+00, 4.1124e-01, -1.4385e+00, -1.2473e+00],
[ 7.2747e-01, -1.9449e+00, 5.0707e-02, -1.2392e+00]],[[-9.2417e-01, 7.7893e-01, -1.2109e+00, 5.5173e-01],
[-3.5855e-04, 6.4165e-01, -1.3656e+00, 6.4585e-02],
[ 6.6905e-01, 1.4401e-01, -2.1580e+00, 2.7015e-01]]])
a_shape= torch.Size([2, 3, 4])
b= tensor([[-9.2417e-01, 7.8566e-01, 1.6108e-01, 1.1380e+00],
[-3.5855e-04, 6.4165e-01, -1.3656e+00, 6.4585e-02],
[ 7.2747e-01, 1.4401e-01, 5.0707e-02, 2.7015e-01]])
b_shape= torch.Size([3, 4])
c= tensor([[1, 0, 0, 0],
[1, 1, 1, 1],
[0, 1, 0, 1]])
c_shape= torch.Size([3, 4])
torch.max()返回最大值与其对应的索引。与tor