dim的定义
dim
表示维度
x = torch.randn(2, 3, 3)
print(x)
print(x.size())
print(x.dim())
输出:
tensor([[[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]],
[[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]]])
torch.Size([2, 3, 3])
3
这样看着不是很清晰,但如果将[]
格式化:
[
[
[-1.6943, -2.1487, 1.2332],
[-0.2261, -0.1596, 1.5513],
[ 2.0383, -0.6982, -2.1481]
],
[
[ 0.4201, -2.7373, 0.2424],
[-1.1152, 1.3682, -1.8322],
[ 0.1957, -0.2920, 0.1845]
]
]
- 维度
(2, 3, 3)
就很明显了, 是从矩阵的外部到内部 - 而
x.dim() = 3
意味着x
有三个维度,dim = (0, 1, 2)
,0
对应着x.size()
中的(2
, 3, 3)1
对应着x.size()
中的(2,3
, 3)2
对应着x.size()
中的(2, 3,3
)
dim的理解
当dim = 0
时, 指的是 x(3, 3)
也就是:
x = torch.randn(2, 3, 3)
print(x)
for i in x:
print(i)
print(i.size())
输出:
tensor([[[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]],
[[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]]])
tensor(
[
[-1.4251, -0.8321, 1.0230],
[ 0.2008, 0.5929, -0.7696],
[-0.3721, -1.0837, -0.6642]
]
)
torch.Size([3, 3])
tensor(
[
[-0.5337, 0.7808, 0.4419],
[-0.4683, 0.3847, 0.0747],
[ 1.0156, -0.4933, 1.5340]
]
)
torch.Size([3, 3])
所以说当dim=0
时, 相当于去除x
中的dim = 0
的维度
验证
- torch.argmax(tensor)
返回tensor中值最大的数的下标, 比较的是同型张量
Example:
>>> x = torch.tensor([1, 5, 8, 4, 6])
>>> torch.argmax(x)
tensor(2)
import torch
x = torch.randn(2, 3, 3)
print(x)
print('='*50, end='\n\n')
for i in x:
print(i)
print(i.size())
print('='*50, end='\n\n')
print(x.size())
print(x.dim())
print('='*50, end='\n\n')
y = torch.argmax(x, dim=0)
print(y)
print(y.size())
输出:
tensor(
[
[
[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]
],
[
[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]
]
]
)
==================================================
tensor([[-1.3918, 0.0620, -0.4111],
[ 1.9623, -1.3399, -0.4673],
[-0.0185, -1.9024, 0.1340]])
torch.Size([3, 3])
tensor([[ 0.7135, -0.5290, -0.7656],
[ 0.2642, 0.5956, -0.0718],
[-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================
torch.Size([2, 3, 3])
3
==================================================
tensor([[1, 0, 0],
[0, 1, 1],
[0, 1, 0]])
torch.Size([3, 3])
-
分析一下
y[0] = [1, 0, 0]
, 为什么呢?
有两种想法:- 它比较的是
[-1.3918, 0.0620, -0.4111]
与[ 0.7135, -0.5290, -0.7656]
其中:
[-1.3918, 0.7135], 0.7135比较大, 所以返回1
[0.0620, -0.5290], 0.0620比较大, 所以返回0
[-0.4111, -0.7656], -0.4111比较大, 所以返回0
- 如果比较的是
x[i]
中的每一列, 得到的是2x3
的输出, 例如x[0]
:
[-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340]
比较每一列, 经过
torch.argmax
得到的是[1, 0, 2]
- 它比较的是
-
如果按照去掉
dim = 0
的部分,x'
:[ [-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340] ], [ [ 0.7135, -0.5290, -0.7656], [ 0.2642, 0.5956, -0.0718], [-0.7465, -0.8098, -0.0874] ]
也就是两个
size = (3, 3)
的tensor
, 这为什么不是第二种情况就比较合理了
因为比较的是两个tensor
, 而第二种情况是分别在一个tensor
内的比较, 再将两个tensor
的比较结果合并- 总结: 比较的是去掉指定维度后的第一个维度, 比如这里的:(
2
, 3, 3) -> (3
, 3), 得到的结果的size
是去掉指定dim
的size
- 总结: 比较的是去掉指定维度后的第一个维度, 比如这里的:(
-
如果只有两个维度, 或许会好理解一些:
import torch x = torch.randn(2,3) print(x) y = torch.argmax(x, dim=0) print(y) print(y.size())
输出:
tensor( [ [ 0.0251, -0.3640, 0.1965], [ 0.6902, 0.9846, 0.2035] ] ) tensor([1, 1, 1]) torch.Size([3])
去掉
dim = 0
, 比较的就是[ 0.0251, -0.3640, 0.1965]
和[ 0.6902, 0.9846, 0.2035]
dim = (2
, 3) -> dim(3
) -
这时候再回来看上面3个维度的例子:
[ [-1.3918, 0.0620, -0.4111], [ 1.9623, -1.3399, -0.4673], [-0.0185, -1.9024, 0.1340] ], [ [ 0.7135, -0.5290, -0.7656], [ 0.2642, 0.5956, -0.0718], [-0.7465, -0.8098, -0.0874] ]
比较两者时相当于在下面的
tensor
做torch.argmax()
[ [-1.3918, 0.0620, -0.4111], [ 0.7135, -0.5290, -0.7656] ]