Softmax
给定softmax的输入 ( z 1 , z 2 , . . . , z n ) (z_1,z_2,...,z_n) (z1,z2,...,zn),则输出为 f ( z 1 , f ( z 2 ) , . . . , f ( z n ) ) f(z_1,f(z_2),...,f(z_n)) f(z1,f(z2),...,f(zn)),其中 f ( z i ) , i ∈ [ 1 , n ] f(z_i),i\in[1,n] f(zi),i∈[1,n]的计算方式为:
f ( z i ) = e z i ∑ j = 1 n e z j f(z_i)=\frac{e^{z_i}}{\sum_{j=1}^ne^{z_j}} f(zi)=∑j=1nezjezi
Sampled softmax
目前流行的基于神经网络的机器翻译(NMT)模型,采用的是Encoder-Decoder结构,对于输入序列 x = ( x 1 , x 2 , . . . , x S n ) \boldsymbol{x}=(x_1,x_2,...,x_{S_n}) x=(x1,x2,...,xSn),生成对应的目标序列 y = ( y 1 , y 2 , . . . , y T n ) \boldsymbol{y}=({y_1,y_2,...,y_{T_n})} y=(y1,y2,...,yTn),模型的建模目标是最大化目标序列的条件概率。
l o g P ( y ∣ x ) = ∑ t = 1 T n P ( y t ∣ y < t , x ) logP(\boldsymbol{y}|\boldsymbol{x})=\sum_{t=1}^{T_n}P(y_t|y_{<t},\boldsymbol{x}) logP(y∣x)=t=1∑TnP(yt∣y<t,x)
对于包含N个样本的训练数据,模型的训练目标就是最大化整体条件概率:
θ ∗ = a r g m a x θ ∑ n = 1 N ∑ t = 1 T n l o g p ( y t n ∣ y < t n , x n ) \theta^*=argmax_{\theta}\sum_{n=1}^N\sum_{t=1}^{T_n}logp(y_t^n|y_{<t}^n,\boldsymbol{x}_n) θ∗=argmaxθn=1∑Nt=1∑Tnlogp(ytn∣y<tn,xn)
详细的模型结构在这里就不再展开,可以参考我之前的文章深度模型(二):Attention,在这里我们主要关注模型softmax层的计算。
softmax层输出目标序列中第 t t t位的符号概率分布,计算方式为:
p
(
y
t
∣
y
<
t
,
x
)
=
e
x
p
(
w
t
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
t
)
Z
p(y_t|y_{<t},x)=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{Z}
p(yt∣y<t,x)=Zexp(wtTϕ(yt−1,zt,ct)+bt)
=
e
x
p
(
w
t
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
t
)
∑
y
k
∈
V
e
x
p
(
w
k
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
k
)
=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\sum_{y_k\in V}exp(w_k^T\phi(y_{t-1},z_t,c_t)+b_k)}
=∑yk∈Vexp(wkTϕ(yt−1,zt,ct)+bk)exp(wtTϕ(yt−1,zt,ct)+bt)
其中V表示目标序列的词汇表, y t − 1 y_{t-1} yt−1表示目标序列中前一位的符号, z t z_t zt表示Decoder当前的隐状态, c t c_t ct表示Encoder隐状态的Attention值。
可以看出为了计算目标符号 y t y_t yt的条件概率,必须计算 Z Z Z值,这需要对词表 V V V中的符号进行遍历,计算量随着词表规模的变大而变大,目前词表的规模从几千到几万不等。
为了支持超大规模的词表,一个很自然的思路就是,能不能通过一些算法达到近似计算 Z Z Z值的目的呢。论文《On Using Very Large Target Vocabulary for Neural Machine Translation》提出了一种对 Z Z Z值的近似计算方法,这就是sampled softmax:
p
(
y
t
∣
y
<
t
,
x
)
=
e
x
p
(
w
t
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
t
)
Z
^
p(y_t|y_{<t},x)=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\widehat Z}
p(yt∣y<t,x)=Z
exp(wtTϕ(yt−1,zt,ct)+bt)
=
e
x
p
(
w
t
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
t
)
∑
y
k
∈
V
′
e
x
p
(
w
k
T
ϕ
(
y
t
−
1
,
z
t
,
c
t
)
+
b
k
)
=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\sum_{y_k\in V'}exp(w_k^T\phi(y_{t-1},z_t,c_t)+b_k)}
=∑yk∈V′exp(wkTϕ(yt−1,zt,ct)+bk)exp(wtTϕ(yt−1,zt,ct)+bt)
其中 V ′ V' V′就是采样得到的词表,词表规模要远小雨整体的词表 V V V,因此整体的词表 V V V的规模不再造成计算量增长的问题。采样方式和 V ′ V' V′的选择方式,以后有时间再补上。