Gumbel Softmax的作用
对离散的分布进行采样
假设如下场景:
模型训练过程中, 网络的输出为
p
=
[
0.1
,
0.7
,
0.2
]
p = [0.1, 0.7, 0.2]
p=[0.1,0.7,0.2], 三个数值分别为"向左", “向上”, "向右"的概率。 我们的决策可能是
y
=
a
r
g
m
a
x
(
p
)
y = argmax(p)
y=argmax(p), 也即选择"向上"这条决策。
但是,这样做会有两个问题:
- (1) a r g m a x ( ) argmax() argmax()的选择不具有随机性。同样的输出 p p p选择 100 100 100次,每次的结果都为"向上"。而按照概率为 0.7 0.7 0.7的含义, 100 100 100次应该有 70 70 70次左右的决策结果是选择"向上".
- (2) a r g m a x ( ) argmax() argmax()函数是不可导的。这样网络就无法通过反向传播进行学习。
而gumbel_softmax的作用就是解决上述这两个子问题.。
1.如何具有随机性
2.如何使得函数可导
参考文献
[1]算法学习之gumbel softmax
[2]Gumbel softmax trick (快速理解附代码)
[3]【Learning Notes】Gumbel 分布及应用浅析