【Pytorch】| Pytorch中softmax的dim的详细总结
作者:刘兴禄,清华大学博士在读
欢迎关注我们的微信公众号 运小筹
关于softmax的理解
Softmax的公式为
softmax
(
x
i
)
=
e
x
i
∑
i
e
x
i
\text{softmax} (x_i) = \frac{e^{x_i}}{\sum_{i}{e^{x_i}}}
softmax(xi)=∑iexiexi
因此,其满足下面两个条件:
- 0 ⩽ softmax ( x i ) ⩽ 1 0 \leqslant \text{softmax} (x_i)\leqslant 1 0⩽softmax(xi)⩽1;
- ∑ i softmax ( x i ) = 1 \sum_{i} {\text{softmax} (x_i)} =1 ∑isoftmax(xi)=1.
这个函数是为了实现对输出向量的归一化,将其标准化为概率的形式。
首先看看官方对tf.nn.functional.softmax(x,dim = -1)
的解释:
dim (python:int) – A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
也就是说,在dim的纬度上,加和为1。比如,是对行加和为1,还是列加和为1。
我们来进行测试
一维向量:dim=0和dim=-1结果相同,dim=1和dim=2会报错
# 假设张量为一维张量: a = torch.tensor([1, 2, 3], dtype=float)
# 注意,在pytorch中计算softmax的时候,张量必须为小数,不能为int类型.需要提前转化好
a = torch.tensor([1, 2, 3], dtype=float)
soft_max_a = torch.nn.functional.softmax(a, dim = 0)
# soft_max_a = tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64) 加和为1
# soft_max_a = torch.nn.functional.softmax(a, dim = 1) # 会报错: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
# soft_max_a = torch.nn.functional.softmax(a, dim = 2) # 会报错
soft_max_a = torch.nn.functional.softmax(a, dim = -1)
# soft_max_a = tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64) # 加和为1
二维张量:dim=1和dim=-1结果相同,dim=2会报错
'''
假设张量为2维张量: a = torch.tensor([[1, 2, 3],
[4, 5, 6]], dtype=float)
'''
# 注意,在pytorch中计算softmax的时候,张量必须为小数,不能为int类型.需要提前转化好
a = torch.tensor([[1, 2, 3],[4, 5, 6]], dtype=float)
soft_max_a = torch.nn.functional.softmax(a, dim = 0) # 按列 加和为1
'''
soft_max_a = tensor([[0.0474, 0.0474, 0.0474],
[0.9526, 0.9526, 0.9526]], dtype=torch.float64)
'''
soft_max_a = torch.nn.functional.softmax(a, dim = 1) # 按行加和为1
'''
soft_max_a = tensor([[0.0900, 0.2447, 0.6652],
[0.0900, 0.2447, 0.6652]], dtype=torch.float64)
'''
# soft_max_a = torch.nn.functional.softmax(a, dim = 2)
# 报错: IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
soft_max_a = torch.nn.functional.softmax(a, dim = -1) # 按行加和为1
'''
soft_max_a = tensor([[0.0900, 0.2447, 0.6652],
[0.0900, 0.2447, 0.6652]], dtype=torch.float64)
'''
最终结论
上面只是对各种情况的探索,可以理解,但是不要生搬硬套。下面的总结才是正确的理解方式,只需要看这里就好。
dim的可选值有0,1, 2, -1,其中:
dim=0
: 第1个维度加和为1,也就是列dim=1
: 第2个维度加和为1,也就是行dim=2
: 第3个维度加和为1,每一个二维矩阵的对应元素加和为1dim=-1
: 最后一个维度加和为1。
– 如果输入向量是1维,则就是该维(也就是对第1维)加和为1;
– 如果输入向量是2维,则就是对列(也就是对第2维)加和为1;
– 依次类推…
欢迎关注我们的微信公众号 运小筹
公众号往期推文如下