Pytorch学习笔记【7】—统计属性
Pytorch笔记目录:点击进入
1. norm
a = torch.full([8],1)
b = a.view(2,4)
c = a.view(2,2,2)
print(b)
print(c)
print(a.norm(1),b.norm(1),c.norm(1))
print(a.norm(2),b.norm(2),c.norm(2))
out:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]])
tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]])
tensor(8.) tensor(8.) tensor(8.)
tensor(2.8284) tensor(2.8284) tensor(2.8284)
2. min max mean prod
获取张量的最大值、最小值、平均值、所有数值的乘积
a = torch.arange(8).view(2,4).float()
print(a)
print(a.min(),a.max(),a.mean(),a.prod())
out:
tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]])
tensor(0.) tensor(7.) tensor(3.5000) tensor(0.)
3. argmax and argmin
返回张量中最大值,最小值对应的索引
a = torch.rand(4,10)
print(a[0])
print(a.argmax())
print(a.argmin())
out:
tensor([0.4453, 0.1426, 0.5995, 0.3761, 0.5277, 0.1995, 0.2860, 0.0388, 0.9950,
0.5010])
tensor(8)
tensor(35)
a = torch.rand(2,3,4)
a.argmax()
out:
tensor(4)
4. dim and keepdim
可以作为max和min的属性值进行填写
a = torch.rand(4,10)
print("---")
print(a.max(dim=1))
print("---")
print(a.max(dim=1,keepdim=True))
out:
---
torch.return_types.max(
values=tensor([0.9836, 0.9341, 0.9224, 0.9486]),
indices=tensor([4, 9, 9, 5]))
---
torch.return_types.max(
values=tensor([[0.9836],
[0.9341],
[0.9224],
[0.9486]]),
indices=tensor([[4],
[9],
[9],
[5]]))
5. topk
topk返回第k小的数值,可以指定维度
print(a.topk(3,dim=1))
print(a.topk(3,dim=1,largest=False))
out:
torch.return_types.topk(
values=tensor([[0.9836, 0.9069, 0.6321],
[0.9341, 0.8707, 0.7575],
[0.9224, 0.8959, 0.7359],
[0.9486, 0.9137, 0.8548]]),
indices=tensor([[4, 2, 6],
[9, 5, 4],
[9, 1, 5],
[5, 2, 7]]))
torch.return_types.topk(
values=tensor([[0.1328, 0.1445, 0.3647],
[0.0516, 0.0944, 0.1258],
[0.0078, 0.0620, 0.3820],
[0.1812, 0.2916, 0.3018]]),
indices=tensor([[9, 0, 1],
[6, 1, 7],
[2, 0, 3],
[4, 6, 8]]))
6. kthvalue
返回最小排行榜里的前k个,可以指定维度
print(a.shape)
print(a.kthvalue(2,dim=1))
out:
torch.Size([4, 10])
torch.return_types.kthvalue(
values=tensor([0.1445, 0.0944, 0.0620, 0.2916]),
indices=tensor([0, 1, 0, 6]))