For what * does, see https://www.cnblogs.com/jony7/p/8035376.html
For torch.max, see https://blog.csdn.net/Z_lbj/article/details/79766690
a.size()
# Out[134]: torch.Size([6, 4, 3])
torch.max(a, 0)[1].size()
# Out[135]: torch.Size([4, 3])
torch.max(a, 1)[1].size()
# Out[136]: torch.Size([6, 3])
torch.max(a, 2)[1].size()
# Out[137]: torch.Size([6, 4])
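The shapes above follow a simple rule: reducing along dim removes that dimension from both the values and the indices. A minimal sketch (using a random tensor, not the `a` from the transcript):

```python
import torch

# Taking the max along a dimension removes that dimension from the
# result; both return values of torch.max have the reduced shape.
a = torch.randn(6, 4, 3)

values, indices = torch.max(a, dim=1)
print(values.shape)   # torch.Size([6, 3])
print(indices.shape)  # torch.Size([6, 3])

# keepdim=True retains the reduced dimension with size 1 instead.
values, indices = torch.max(a, dim=1, keepdim=True)
print(values.shape)   # torch.Size([6, 1, 3])
```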
See below for exactly how the comparison is done.
b
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[ 12., 13., 14., 15.],
[ 16., 17., 18., 19.],
[ 20., 21., 22., 23.]]])
torch.max(b,0)[0]
tensor([[ 12., 13., 14., 15.],
[ 16., 17., 18., 19.],
[ 20., 21., 22., 23.]])
torch.max(b,1)[0]
tensor([[ 8., 9., 10., 11.],
[ 20., 21., 22., 23.]])
torch.max(b,2)[0]
tensor([[ 3., 7., 11.],
[ 15., 19., 23.]])
The corresponding indices are obtained as follows:
torch.max(b,0)[1]
tensor([[ 1, 1, 1, 1],
[ 1, 1, 1, 1],
[ 1, 1, 1, 1]])
torch.max(b,1)[1]
tensor([[ 2, 2, 2, 2],
[ 2, 2, 2, 2]])
torch.max(b,2)[1]
tensor([[ 3, 3, 3],
[ 3, 3, 3]])
torch.sum:
torch.sum(input) → Tensor
torch.sum(input, dim, out=None) → Tensor
Parameters:
input (Tensor) – the input tensor
dim (int) – the dimension to reduce
out (Tensor, optional) – the output tensor
The function returns a tensor.
match
out:
tensor([[[ 0, 0, 2, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]],
[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]]], dtype=torch.uint8)
torch.sum(match)
Out:
tensor(2)
torch.sum(match,0)
Out:
tensor([[ 0, 0, 2, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]])
torch.sum(match,1)
Out:
tensor([[ 0, 0, 2, 0],
[ 0, 0, 0, 0]])
torch.sum(match,2)
Out:
tensor([[ 2, 0, 0],
[ 0, 0, 0]])
One more thing worth adding is the item() method: if a tensor holds exactly one element, calling item() converts it to a Python scalar; if the tensor has more than one element, a ValueError is raised, as shown below.
b.item()
Traceback (most recent call last):
b.item()
ValueError: only one element tensors can be converted to Python scalars
torch.sum(b)
Out: tensor(276.)
torch.sum(b).item()
Out: 276.0
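To summarize the behavior above in one runnable sketch (with a small made-up tensor): item() works only on single-element tensors, including those produced by indexing, while tolist() is the multi-element counterpart.

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

print(torch.sum(t).item())   # 10.0, a plain Python float
print(t[0, 1].item())        # 2.0 — indexing yields a single-element tensor
print(t.tolist())            # nested Python lists for multi-element tensors

# Calling item() on a multi-element tensor fails; depending on the
# PyTorch version the error class may differ, so catch both.
try:
    t.item()
except (ValueError, RuntimeError) as e:
    print(e)
```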
So what does an item method usually look like in Python more generally? See https://blog.csdn.net/qq_34941023/article/details/78431376