[10,000-Word Long Read] Word2Vec Computation in Detail (Part 2): The Skip-gram Model
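Preface
This article covers the second Word2Vec model, the Skip-gram model. The series consists of:
[10,000-Word Long Read] Word2Vec Computation in Detail (Part 1): The CBOW Model, 9000+ lines of markdown
[10,000-Word Long Read] Word2Vec Computation in Detail (Part 2): The Skip-gram Model, 12000+ lines of markdown
[10,000-Word Long Read] Word2Vec Computation in Detail (Part 3): Hierarchical Softmax and Negative Sampling, 18000+ lines of markdown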
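The Skip-gram Model
The Skip-gram model is an important Word2Vec model in natural language processing (NLP). It is the mirror image of CBOW: given a single input word, the model is asked to predict the words in its context.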
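Model structure
The input to the Skip-gram model is the one-hot vector of the target word. A first linear transformation turns it into a hidden vector used to predict the context words, a second linear transformation produces a score table for each context word, and a final multi-class classification step yields the predicted context words. The structure of the Skip-gram model is shown in the figure below.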
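Preprocessing
Before describing the model input, we briefly cover the processing that precedes it. Preprocessing is the same as for the CBOW model: we need the one-hot representation of the vocabulary. See the preprocessing section of the CBOW article for details.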
Model input
The target word is represented as a one-hot encoded vector and used as the input $x_i$ of the Skip-gram model, where $x_i \in \mathbb{R}^{V \times 1}$ and $i$ is the index of the target word (its position in the vocabulary).
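As a minimal sketch of this input encoding, the hypothetical helper `one_hot` below builds $x_i$ as a plain Python list standing in for a $V \times 1$ column vector. It is not the `convert_one_hot` function from the CBOW article, whose implementation is not reproduced here.

```python
def one_hot(index, vocab_size):
    """Return a length-V list with 1.0 at `index` and 0.0 elsewhere."""
    if not 0 <= index < vocab_size:
        raise ValueError("index must lie in [0, vocab_size)")
    vec = [0.0] * vocab_size
    vec[index] = 1.0
    return vec


# Illustrative values: a vocabulary of 10 words, target word at index 2.
x_2 = one_hot(2, 10)
print(x_2)
```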
Input weight layer
In this layer, we multiply the one-hot encoded target word $x_i \in \mathbb{R}^{V \times 1}$ by the input weight matrix $W$ of the hidden layer and add the bias $b$ to obtain the hidden-layer vector $h$, where $W \in \mathbb{R}^{D \times V}$, $b \in \mathbb{R}^{D \times 1}$, and $h \in \mathbb{R}^{D \times 1}$. In matrix form:
$$h = Wx_i + b$$
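A small sketch of this layer, using made-up toy sizes ($V = 3$, $D = 2$) and made-up weights; `matvec` and `hidden_layer` are hypothetical helper names. It also illustrates that, because $x_i$ is one-hot, $Wx_i$ simply picks out column $i$ of $W$.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a column vector (list)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def hidden_layer(W, b, x):
    """Compute h = W x + b for W of shape D x V, b of length D, x of length V."""
    return [wx + bias for wx, bias in zip(matvec(W, x), b)]


# Toy values chosen for illustration only (D = 2, V = 3).
W = [[0.1, 0.2, 0.3],
     [0.4, 0.5, 0.6]]
b = [1.0, -1.0]
x_1 = [0.0, 1.0, 0.0]              # one-hot vector for index 1

h = hidden_layer(W, b, x_1)
column_1 = [row[1] for row in W]   # W x_1 is exactly this column
print(h, column_1)
```

In practice this is why the product with a one-hot vector is usually implemented as a column (or row) lookup rather than a full matrix multiplication.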
Output weight layer
We multiply $h$ by the output weight matrix $W_j' \in \mathbb{R}^{V \times D}$ of the hidden layer and add the bias $b_j' \in \mathbb{R}^{V \times 1}$ to obtain the score vector $S_j \in \mathbb{R}^{V \times 1}$ of each context word, with $S = (S_1, S_2, \dots, S_{2C})$. Here $W_j'$ is the output weight matrix for the predicted context word at position index $j$, $b_j'$ is the corresponding bias, and $S_j$ is the score vector for the predicted context word at position index $j$, with $j = 1, 2, 3, \dots, 2C$. $C$ is the window size: to predict a context of window size $C$, we predict the $C$ words before and the $C$ words after the target word, $2C$ words in total. In matrix form:
$$S_j = W_j'h + b_j'$$
where $j = 1, 2, \dots, C$ indexes the preceding context and $j = C+1, C+2, \dots, 2C$ indexes the following context; together they form the full set of context indices. For later use, we write $S_j = (S_j(0), S_j(1), \dots, S_j(V-1))^T$.
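To make the slot indexing concrete, here is a sketch of how the $2C$ context slots of a target position could be enumerated from a corpus of word ids. The function name `context_slots` is hypothetical, and two choices are assumptions rather than something the text specifies: slots are listed from the farthest preceding word to the farthest following word, and slots that fall outside the corpus are reported as `None`.

```python
def context_slots(corpus_ids, position, window):
    """Return the 2*window context word ids around `position`.

    Slots j = 1..C hold the preceding words and slots j = C+1..2C the
    following words. Positions outside the corpus yield None.
    """
    offsets = list(range(-window, 0)) + list(range(1, window + 1))
    slots = []
    for offset in offsets:
        p = position + offset
        slots.append(corpus_ids[p] if 0 <= p < len(corpus_ids) else None)
    return slots


# Hypothetical corpus of word ids, used only for illustration.
corpus_ids = [0, 1, 2, 3, 0, 4]
print(context_slots(corpus_ids, 2, 2))   # context of the word at position 2
print(context_slots(corpus_ids, 0, 2))   # first word: no preceding context
```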
Softmax layer
We convert the scores $S_j$ produced by the output layer into probabilities $P_j$ with Softmax. $P_j = (P_j(0), P_j(1), \dots, P_j(V-1))^T$ is the probability vector of the context word at position index $j$, where $P_j \in \mathbb{R}^{V \times 1}$. $S_j(k)$ denotes the score in row $k$ of the score vector $S_j$, and $P_j(k)$ denotes the probability in row $k$ of the vector $P_j$. Element by element, the Softmax computation is
$$P_j(k) = \text{Softmax}(S_j)(k) = \frac{\exp(S_j(k))}{\sum\limits_{l=0}^{V-1} \exp(S_j(l))}$$
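The formula translates almost directly into code. The sketch below is one common way to write it; subtracting the maximum score before exponentiating is a standard numerical-stability trick that is not part of the formula above but leaves the result unchanged, since the common factor cancels in the ratio. The name `softmax` is just an illustrative helper.

```python
import math


def softmax(scores):
    """Map a list of scores to probabilities that sum to 1."""
    shift = max(scores)                          # avoids overflow in exp
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative scores, not taken from the article.
probs = softmax([1.0, 2.0, 3.0])
print(probs)
print(softmax([1001.0, 1002.0, 1003.0]))         # same result, no overflow
```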
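Model output
For each $P_j$, the model output is obtained by setting the position with the largest probability to 1 and all other positions to 0, which gives a one-hot encoding. From this one-hot encoding we look up the corresponding word and take it as the predicted context word. This is the output of the Skip-gram model.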
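A simple Skip-gram example
We now walk through the Skip-gram computation on an example. Suppose the corpus is text = 'The cat plays in the garden, and the cat chases the mouse in the garden.' We process it with the functions preprocess and convert_one_hot given in the preprocessing section, which yield the following results.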
index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
$x_i$ | $x_0$ | $x_1$ | $x_2$ | $x_3$ | $x_4$ | $x_5$ | $x_6$ | $x_7$ | $x_8$ | $x_9$ |
word | the | cat | plays | in | garden | , | and | chases | mouse | . |
Result of the preprocess function (the vocabulary)
The table above shows the vocabulary, whose size is $V = 10$. The one-hot matrix obtained by the conversion is denoted $X_{onehot}$. Each column of $X_{onehot}$ is the one-hot vector of the word with the corresponding index, i.e. $x_i$ is the one-hot vector for index $i$.
$$X_{onehot} = (x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9) = \begin{bmatrix} 1&0&0&0&0&0&0&0&0&0\\ 0&1&0&0&0&0&0&0&0&0\\ 0&0&1&0&0&0&0&0&0&0\\ 0&0&0&1&0&0&0&0&0&0\\ 0&0&0&0&1&0&0&0&0&0\\ 0&0&0&0&0&1&0&0&0&0\\ 0&0&0&0&0&0&1&0&0&0\\ 0&0&0&0&0&0&0&1&0&0\\ 0&0&0&0&0&0&0&0&1&0\\ 0&0&0&0&0&0&0&0&0&1 \end{bmatrix}$$
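For readers without the CBOW article at hand, the following is a rough stand-in for what such preprocessing could look like. It is an assumption, not the author's `preprocess` or `convert_one_hot`: the names `build_vocab` and `one_hot_matrix`, the lowercasing, and the regular-expression tokenizer are illustrative choices that happen to reproduce the vocabulary table above.

```python
import re


def build_vocab(text):
    """Lowercase, split into words and punctuation marks, index by first appearance."""
    tokens = re.findall(r"[a-z]+|[.,]", text.lower())
    word_to_id = {}
    for token in tokens:
        if token not in word_to_id:
            word_to_id[token] = len(word_to_id)
    corpus_ids = [word_to_id[t] for t in tokens]
    return corpus_ids, word_to_id


def one_hot_matrix(vocab_size):
    """Return the V x V identity matrix; column i is the one-hot vector x_i."""
    return [[1 if r == c else 0 for c in range(vocab_size)]
            for r in range(vocab_size)]


text = "The cat plays in the garden, and the cat chases the mouse in the garden."
corpus_ids, word_to_id = build_vocab(text)
X_onehot = one_hot_matrix(len(word_to_id))
print(word_to_id)
print(corpus_ids)
```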
We assume a window size of $C = 2$ and a hidden-layer dimension of $D = 4$, and predict the context of "plays". The model input is therefore $x_2$, the one-hot vector of "plays", so $X = (x_2) = (0,0,1,0,0,0,0,0,0,0)^T$. We initialize the input weight matrix $W$ and the bias $b$.
$$W = \begin{bmatrix} -0.2047 & 0.4789 & -0.5194 & -0.5557 & 1.9657 & 1.3934 & 0.0929 & 0.2817 & 0.769 & 1.2464\\ 1.0071 & -1.2962 & 0.2749 & 0.2289 & 1.3529 & 0.8864 & -2.0016 & -0.3718 & 1.669 & -0.4385\\ -0.5397 & 0.4769 & 3.2489 & -1.0212 & -0.577 & 0.1241 & 0.3026 & 0.5237 & 0.0009 & 1.3438\\ -0.7135 & -0.8311 & -2.3702 & -1.8607 & -0.8607 & 0.5601 & -1.2659 & 0.1198 & -1.0635 & 0.3328 \end{bmatrix}$$
$$b = \begin{bmatrix} 1.1274 \\ -0.5683 \\ 0.3093 \\ -0.5773 \end{bmatrix}$$
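Next comes the input weight layer. We multiply $W$ by $X$ and add $b$ to obtain $h$. Because $X$ is one-hot, $WX$ is simply the column of $W$ at index 2, so
$$h = WX + b = \begin{bmatrix} -0.5194 \\ 0.2749 \\ 3.2489 \\ -2.3702 \end{bmatrix} + \begin{bmatrix} 1.1274 \\ -0.5683 \\ 0.3093 \\ -0.5773 \end{bmatrix} = \begin{bmatrix} 0.6080 & -0.2934 & 3.5582 & -2.9475 \end{bmatrix}^T$$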
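This step can be checked numerically. The sketch below copies $W$ and $b$ from above (rounded to four decimals, so the last digit may differ from a full-precision run) and evaluates both $Wx_2$ and $Wx_2 + b$; the helper name `matvec` is illustrative.

```python
W = [
    [-0.2047, 0.4789, -0.5194, -0.5557, 1.9657, 1.3934, 0.0929, 0.2817, 0.7690, 1.2464],
    [1.0071, -1.2962, 0.2749, 0.2289, 1.3529, 0.8864, -2.0016, -0.3718, 1.6690, -0.4385],
    [-0.5397, 0.4769, 3.2489, -1.0212, -0.5770, 0.1241, 0.3026, 0.5237, 0.0009, 1.3438],
    [-0.7135, -0.8311, -2.3702, -1.8607, -0.8607, 0.5601, -1.2659, 0.1198, -1.0635, 0.3328],
]
b = [1.1274, -0.5683, 0.3093, -0.5773]
x_2 = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]     # one-hot vector of "plays"


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a column vector (list)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


Wx = matvec(W, x_2)                       # equals column 2 of W
h = [wx + bias for wx, bias in zip(Wx, b)]
print([round(v, 4) for v in Wx])
print([round(v, 4) for v in h])
```

With these rounded inputs, $Wx_2$ is $(-0.5194, 0.2749, 3.2489, -2.3702)^T$ and adding $b$ gives $h \approx (0.6080, -0.2934, 3.5582, -2.9475)^T$, which is the hidden vector that reproduces the scores $S_j$ computed below.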
Next comes the output weight layer. We randomly initialize $W_j'$ and $b_j'$ and then compute $W_j'h + b_j'$ to obtain the scores $S_j$ for the corresponding context position. Since the window size is $C = 2$, there are $2 \times C = 2 \times 2 = 4$ context words to predict. That is, we have four pairs of output weight matrices and biases, $W_1', W_2', W_3', W_4'$ and $b_1', b_2', b_3', b_4'$, and we compute $S_1, S_2, S_3, S_4$ with $S_j = W_j'h + b_j'$.
Taking $S_1$ as an example, we first initialize $W_1'$ and $b_1'$.
$$W_1' = \begin{bmatrix} -2.3594 & -1.3070 & 0.3312 & -0.0118\\ -1.5491 & 0.8625 & 0.8529 & -0.6524\\ 0.7236 & -0.6222 & 0.0513 & 1.0107\\ -0.1315 & -0.1149 & 0.1181 & -1.5656\\ -0.4825 & -0.5894 & 0.9299 & 0.2204\\ -2.2528 & -0.2745 & -0.4170 & 1.6347\\ 1.3067 & -0.8239 & 0.1869 & 0.6803\\ -0.3042 & -0.3674 & -0.4162 & -0.7769\\ 1.9207 & 0.7273 & -0.9192 & -0.5674 \\ 1.2098 & -0.3957 & 0.8387 & -1.0209 \end{bmatrix}, \quad b_1' = \begin{bmatrix} -1.1686 \\ -0.7519 \\ -0.4937 \\ -0.8468 \\ -0.4456 \\ 0.6254 \\ -0.3501 \\ -1.0522 \\ -1.0623 \\ -1.3675 \end{bmatrix}$$
Then, by the formula $S_j = W_j'h + b_j'$, we have
$$S_1 = W_1'h + b_1' = \begin{bmatrix} -1.0059 \\ 3.0113 \\ -2.6677 \\ 4.1419 \\ 2.093 \\ -6.9661 \\ -0.6537 \\ -0.3203 \\ -1.7061 \\ 5.4778 \end{bmatrix}$$
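As a cross-check, the sketch below recomputes $S_1$ from the listed $W_1'$ and $b_1'$, taking $h = Wx_2 + b \approx (0.6080, -0.2934, 3.5582, -2.9475)^T$ from the values of $W$ and $b$ given earlier. Because all inputs are rounded to four decimals, the result should agree with the vector above only to about three decimal places.

```python
W_1_prime = [
    [-2.3594, -1.3070, 0.3312, -0.0118],
    [-1.5491, 0.8625, 0.8529, -0.6524],
    [0.7236, -0.6222, 0.0513, 1.0107],
    [-0.1315, -0.1149, 0.1181, -1.5656],
    [-0.4825, -0.5894, 0.9299, 0.2204],
    [-2.2528, -0.2745, -0.4170, 1.6347],
    [1.3067, -0.8239, 0.1869, 0.6803],
    [-0.3042, -0.3674, -0.4162, -0.7769],
    [1.9207, 0.7273, -0.9192, -0.5674],
    [1.2098, -0.3957, 0.8387, -1.0209],
]
b_1_prime = [-1.1686, -0.7519, -0.4937, -0.8468, -0.4456,
             0.6254, -0.3501, -1.0522, -1.0623, -1.3675]
h = [0.6080, -0.2934, 3.5582, -2.9475]    # h = W x_2 + b, rounded

# S_1 = W_1' h + b_1', one score per vocabulary word.
S_1 = [sum(w * v for w, v in zip(row, h)) + bias
       for row, bias in zip(W_1_prime, b_1_prime)]
print([round(s, 4) for s in S_1])
```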
$S_2$, $S_3$, and $S_4$ are computed in the same way, as shown below.
$$S_2 = W_2'h + b_2' = \begin{bmatrix} -0.1995 & 0.2864 & 1.3497 & 1.0048 \\ 0.0222 & -0.0100 & -0.9559 & -1.2183 \\ 0.6900 & -0.9212 & -1.1577 & 1.8249 \\ 0.9124 & 2.0037 & -0.7485 & -0.5625 \\ -0.0363 & 1.5817 & -1.5693 & -0.1934 \\ -1.1668 & -0.1391 & -0.0170 & 0.9890 \\ -0.4406 & 1.3206 & -0.3917 & 0.6355 \\ -1.6778 & 1.0459 & -0.1167 & 1.4402 \\ 0.7464 & -0.8687 & -0.8388 & -0.3726 \\ 1.2700 & -0.2894 & 0.2669 & -1.4134 \end{bmatrix} \begin{bmatrix} 0.6080 \\ -0.2934 \\ 3.5582 \\ -2.9475 \end{bmatrix} + \begin{bmatrix} -0.825 \\ -0.1326 \\ 1.2399 \\ 0.6032 \\ 0.4683 \\ 1.0228 \\ 0.2179 \\ 1.4366 \\ 0.2373 \\ -0.0302 \end{bmatrix} = \begin{bmatrix} 0.8106 \\ 0.0736 \\ -7.5685 \\ -0.4352 \\ -5.0315 \\ -2.6214 \\ -3.7044 \\ -4.5506 \\ -0.9403 \\ 5.9425 \end{bmatrix}$$
$$S_3 = W_3'h + b_3' = \begin{bmatrix} -1.5420 & 0.3780 & 0.0699 & 1.3272 \\ 0.7584 & 0.0500 & -0.0235 & -1.3326 \\ 1.0015 & -0.7262 & 0.8167 & -0.9975 \\ 0.1882 & 0.0296 & 0.5850 & -0.0327 \\ 1.0954 & -0.5287 & -1.0225 & 0.6692 \\ 0.3536 & 0.1077 & -1.2242 & 0.4579 \\ -0.3014 & 0.5080 & -0.2723 & -0.7572 \\ 0.4270 & 1.2200 & -1.8448 & -0.1106 \\ 2.2247 & -1.2139 & 0.4352 & -0.9266 \\ -0.9744 & -0.7343 & 0.7212 & 1.2966 \end{bmatrix} \begin{bmatrix} 0.6080 \\ -0.2934 \\ 3.5582 \\ -2.9475 \end{bmatrix} + \begin{bmatrix} -2.6444 \\ 1.4572 \\ -0.1357 \\ 1.2635 \\ -0.9616 \\ 1.1074 \\ -0.8948 \\ -0.5762 \\ 0.0009 \\ 0.9404 \end{bmatrix} = \begin{bmatrix} -7.3561 \\ 5.7478 \\ 6.5325 \\ 3.5469 \\ -5.751 \\ -4.4147 \\ 0.0358 \\ -6.9127 \\ 5.989 \\ -0.6921 \end{bmatrix}$$
$$S_4 = W_4'h + b_4' = \begin{bmatrix} -0.9707 & -0.7539 & 0.2467 & -0.9193 \\ -0.6605 & 0.6702 & -2.3042 & 1.0746 \\ -0.5031 & 0.2229 & 0.4336 & 0.8506 \\ 2.1695 & 0.7953 & 0.1527 & -0.9290 \\ 0.9809 & 0.4570 & -0.4028 & -1.6490 \\ 0.7021 & -0.6065 & -1.8008 & 0.5552 \\ 0.4988 & -0.6534 & -0.0171 & 0.7181 \\ -1.5637 & -0.2477 & 2.0687 & 1.2274 \\ -0.6794 & -0.4706 & -0.5578 & 1.7551 \\ -0.6347 & -0.7285 & 0.9110 & 0.2523 \end{bmatrix} \begin{bmatrix} 0.6080 \\ -0.2934 \\ 3.5582 \\ -2.9475 \end{bmatrix} + \begin{bmatrix} -0.1529 \\ 0.6095 \\ 1.43 \\ -0.2554 \\ -1.8245 \\ 0.0909 \\ -1.7414 \\ -2.4202 \\ 0.0652 \\ -0.6424 \end{bmatrix} = \begin{bmatrix} 3.0653 \\ -11.3551 \\ 0.0944 \\ 4.1118 \\ 2.0648 \\ -7.3483 \\ -3.4239 \\ 0.4448 \\ -7.3677 \\ 1.6833 \end{bmatrix}$$
We thus obtain $S = (S_1, S_2, S_3, S_4)$:
$$S = (S_1, S_2, S_3, S_4) = \begin{bmatrix} -1.0059 & 0.8106 & -7.3561 & 3.0653 \\ 3.0113 & 0.0736 & 5.7478 & -11.3551 \\ -2.6677 & -7.5685 & 6.5325 & 0.0944 \\ 4.1419 & -0.4352 & 3.5469 & 4.1118 \\ 2.093 & -5.0315 & -5.751 & 2.0648 \\ -6.9661 & -2.6214 & -4.4147 & -7.3483 \\ -0.6537 & -3.7044 & 0.0358 & -3.4239 \\ -0.3203 & -4.5506 & -6.9127 & 0.4448 \\ -1.7061 & -0.9403 & 5.989 & -7.3677 \\ 5.4778 & 5.9425 & -0.6921 & 1.6833 \end{bmatrix}$$
Next comes the Softmax layer, whose formula is
$$P_j(k) = \text{Softmax}(S_j)(k) = \frac{\exp(S_j(k))}{\sum\limits_{l=0}^{V-1} \exp(S_j(l))}$$
The computation proceeds as follows.
Taking $P_1$ as an example, we first apply $e^x$ to every element of $S_1$, i.e. $S_1'(k) = e^{S_1(k)}$, which gives
$$\begin{aligned} S_1' &= \begin{bmatrix} S_{1}'(0) & S_{1}'(1) & S_{1}'(2) & S_{1}'(3) & S_{1}'(4) & S_{1}'(5) & S_{1}'(6) & S_{1}'(7) & S_{1}'(8) & S_{1}'(9) \end{bmatrix}^T \\ &= \begin{bmatrix} e^{-1.0059} & e^{3.0113} & e^{-2.6677} & e^{4.1419} & e^{2.093} & e^{-6.9661} & e^{-0.6537} & e^{-0.3203} & e^{-1.7061} & e^{5.4778} \end{bmatrix}^T \\ &= \begin{bmatrix} 0.3657 & 20.3137 & 0.0694 & 62.9222 & 8.1092 & 0.0009 & 0.5201 & 0.7259 & 0.1815 & 239.3196 \end{bmatrix}^T \end{aligned}$$
Summing the elements $S_1'(k)$ of $S_1'$ for $k = 0, 1, \dots, V-1$ gives $Sum_1$:
$$\begin{aligned} Sum_1 &= \sum\limits_{l = 0}^{V-1} S_{1}'(l) \\ &= e^{-1.0059} + e^{3.0113} + e^{-2.6677} + e^{4.1419} + e^{2.093} + e^{-6.9661} + e^{-0.6537} + e^{-0.3203} + e^{-1.7061} + e^{5.4778} \\ &= 0.3657 + 20.3137 + 0.0694 + 62.9222 + 8.1092 + 0.0009 + 0.5201 + 0.7259 + 0.1815 + 239.3196 \\ &= 332.5281 \end{aligned}$$
Then dividing every element of $S_1'$ by $Sum_1$ gives $P_1$, i.e.
$$\begin{aligned} P_1 &= \frac{\begin{bmatrix} S_{1}'(0) & S_{1}'(1) & S_{1}'(2) & S_{1}'(3) & S_{1}'(4) & S_{1}'(5) & S_{1}'(6) & S_{1}'(7) & S_{1}'(8) & S_{1}'(9) \end{bmatrix}^T}{Sum_1} \\ &= \frac{\begin{bmatrix} 0.3657 & 20.3137 & 0.0694 & 62.9222 & 8.1092 & 0.0009 & 0.5201 & 0.7259 & 0.1815 & 239.3196 \end{bmatrix}^T}{332.5281} \\ &= \begin{bmatrix} 0.001 & 0.061 & 0.0002 & 0.1892 & 0.0243 & 0 & 0.0015 & 0.0021 & 0.0005 & 0.7196 \end{bmatrix}^T \end{aligned}$$
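The three steps above (exponentiate, sum, divide) can be replayed in a few lines. The sketch takes the rounded $S_1$ from the text as input, so its output matches the printed $P_1$ only up to rounding.

```python
import math

S_1 = [-1.0059, 3.0113, -2.6677, 4.1419, 2.093,
       -6.9661, -0.6537, -0.3203, -1.7061, 5.4778]

exp_S_1 = [math.exp(s) for s in S_1]      # S_1' in the text
sum_1 = sum(exp_S_1)                      # Sum_1 in the text
P_1 = [e / sum_1 for e in exp_S_1]        # probabilities over the vocabulary

print(round(sum_1, 2))
print([round(p, 4) for p in P_1])
```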
$P_2$, $P_3$, and $P_4$ are computed in the same way, as shown below. First we compute $S_2'$, $S_3'$, and $S_4'$:
$$\begin{aligned} S_2' &= \begin{bmatrix} S_{2}'(0) & S_{2}'(1) & S_{2}'(2) & S_{2}'(3) & S_{2}'(4) & S_{2}'(5) & S_{2}'(6) & S_{2}'(7) & S_{2}'(8) & S_{2}'(9) \end{bmatrix}^T \\ &= \begin{bmatrix} e^{0.8106} & e^{0.0736} & e^{-7.5685} & e^{-0.4352} & e^{-5.0315} & e^{-2.6214} & e^{-3.7044} & e^{-4.5506} & e^{-0.9403} & e^{5.9425} \end{bmatrix}^T \\ &= \begin{bmatrix} 2.2492 & 1.0763 & 0.0005 & 0.6471 & 0.0065 & 0.0727 & 0.0246 & 0.0105 & 0.3905 & 380.8859 \end{bmatrix}^T \end{aligned}$$
$$\begin{aligned} S_3' &= \begin{bmatrix} S_{3}'(0) & S_{3}'(1) & S_{3}'(2) & S_{3}'(3) & S_{3}'(4) & S_{3}'(5) & S_{3}'(6) & S_{3}'(7) & S_{3}'(8) & S_{3}'(9) \end{bmatrix}^T \\ &= \begin{bmatrix} e^{-7.3561} & e^{5.7478} & e^{6.5325} & e^{3.5469} & e^{-5.751} & e^{-4.4147} & e^{0.0358} & e^{-6.9127} & e^{5.989} & e^{-0.6921} \end{bmatrix}^T \\ &= \begin{bmatrix} 0.0006 & 313.5002 & 687.1138 & 34.7055 & 0.0031 & 0.012 & 1.0364 & 0.0009 & 399.0153 & 0.5005 \end{bmatrix}^T \end{aligned}$$
$$\begin{aligned} S_4' &= \begin{bmatrix} S_{4}'(0) & S_{4}'(1) & S_{4}'(2) & S_{4}'(3) & S_{4}'(4) & S_{4}'(5) & S_{4}'(6) & S_{4}'(7) & S_{4}'(8) & S_{4}'(9) \end{bmatrix}^T \\ &= \begin{bmatrix} e^{3.0653} & e^{-11.3551} & e^{0.0944} & e^{4.1118} & e^{2.0648} & e^{-7.3483} & e^{-3.4239} & e^{0.4448} & e^{-7.3677} & e^{1.6833} \end{bmatrix}^T \\ &= \begin{bmatrix} 21.4408 & 0 & 1.0989 & 61.0565 & 7.8837 & 0.0006 & 0.0325 & 1.5601 & 0.0006 & 5.3832 \end{bmatrix}^T \end{aligned}$$
Then we compute $Sum_2$, $Sum_3$, and $Sum_4$:
$$\begin{aligned} Sum_2 &= \sum\limits_{l = 0}^{V-1} S_2'(l) \\ &= e^{0.8106} + e^{0.0736} + e^{-7.5685} + e^{-0.4352} + e^{-5.0315} + e^{-2.6214} + e^{-3.7044} + e^{-4.5506} + e^{-0.9403} + e^{5.9425} \\ &= 2.2492 + 1.0763 + 0.0005 + 0.6471 + 0.0065 + 0.0727 + 0.0246 + 0.0105 + 0.3905 + 380.8859 \\ &= 385.3637 \end{aligned}$$
$$\begin{aligned} Sum_3 &= \sum\limits_{l = 0}^{V-1} S_{3}'(l) \\ &= e^{-7.3561} + e^{5.7478} + e^{6.5325} + e^{3.5469} + e^{-5.751} + e^{-4.4147} + e^{0.0358} + e^{-6.9127} + e^{5.989} + e^{-0.6921} \\ &= 0.0006 + 313.5002 + 687.1138 + 34.7055 + 0.0031 + 0.012 + 1.0364 + 0.0009 + 399.0153 + 0.5005 \\ &= 1435.8883 \end{aligned}$$
$$\begin{aligned} Sum_4 &= \sum\limits_{l = 0}^{V-1} S_{4}'(l) \\ &= e^{3.0653} + e^{-11.3551} + e^{0.0944} + e^{4.1118} + e^{2.0648} + e^{-7.3483} + e^{-3.4239} + e^{0.4448} + e^{-7.3677} + e^{1.6833} \\ &= 21.4408 + 0 + 1.0989 + 61.0565 + 7.8837 + 0.0006 + 0.0325 + 1.5601 + 0.0006 + 5.3832 \\ &= 98.4569 \end{aligned}$$
Finally we compute $P_2$, $P_3$, and $P_4$:
$$\begin{aligned} P_2 &= \frac{\begin{bmatrix} S_{2}'(0) & S_{2}'(1) & S_{2}'(2) & S_{2}'(3) & S_{2}'(4) & S_{2}'(5) & S_{2}'(6) & S_{2}'(7) & S_{2}'(8) & S_{2}'(9) \end{bmatrix}^T}{Sum_2} \\ &= \frac{\begin{bmatrix} 2.2492 & 1.0763 & 0.0005 & 0.6471 & 0.0065 & 0.0727 & 0.0246 & 0.0105 & 0.3905 & 380.8859 \end{bmatrix}^T}{385.3637} \\ &= \begin{bmatrix} 0.0057 & 0.0027 & 0 & 0.0016 & 0 & 0.0001 & 0 & 0 & 0.001 & 0.9883 \end{bmatrix}^T \end{aligned}$$
$$\begin{aligned} P_3 &= \frac{\begin{bmatrix} S_{3}'(0) & S_{3}'(1) & S_{3}'(2) & S_{3}'(3) & S_{3}'(4) & S_{3}'(5) & S_{3}'(6) & S_{3}'(7) & S_{3}'(8) & S_{3}'(9) \end{bmatrix}^T}{Sum_3} \\ &= \frac{\begin{bmatrix} 0.0006 & 313.5002 & 687.1138 & 34.7055 & 0.0031 & 0.012 & 1.0364 & 0.0009 & 399.0153 & 0.5005 \end{bmatrix}^T}{1435.8883} \\ &= \begin{bmatrix} 0 & 0.2183 & 0.4785 & 0.0241 & 0 & 0 & 0.0007 & 0 & 0.2778 & 0.0002 \end{bmatrix}^T \end{aligned}$$
$$\begin{aligned} P_4 &= \frac{\begin{bmatrix} S_{4}'(0) & S_{4}'(1) & S_{4}'(2) & S_{4}'(3) & S_{4}'(4) & S_{4}'(5) & S_{4}'(6) & S_{4}'(7) & S_{4}'(8) & S_{4}'(9) \end{bmatrix}^T}{Sum_4} \\ &= \frac{\begin{bmatrix} 21.4408 & 0 & 1.0989 & 61.0565 & 7.8837 & 0.0006 & 0.0325 & 1.5601 & 0.0006 & 5.3832 \end{bmatrix}^T}{98.4569} \\ &= \begin{bmatrix} 0.2177 & 0 & 0.0111 & 0.6201 & 0.08 & 0 & 0.0002 & 0.0158 & 0 & 0.0546 \end{bmatrix}^T \end{aligned}$$
We thus obtain $P = (P_1, P_2, P_3, P_4)$:
$$P = (P_1, P_2, P_3, P_4) = \begin{bmatrix} P_{1}(0) & P_{2}(0) & P_{3}(0) & P_{4}(0) \\ P_{1}(1) & P_{2}(1) & P_{3}(1) & P_{4}(1) \\ P_{1}(2) & P_{2}(2) & P_{3}(2) & P_{4}(2) \\ P_{1}(3) & P_{2}(3) & P_{3}(3) & P_{4}(3) \\ P_{1}(4) & P_{2}(4) & P_{3}(4) & P_{4}(4) \\ P_{1}(5) & P_{2}(5) & P_{3}(5) & P_{4}(5) \\ P_{1}(6) & P_{2}(6) & P_{3}(6) & P_{4}(6) \\ P_{1}(7) & P_{2}(7) & P_{3}(7) & P_{4}(7) \\ P_{1}(8) & P_{2}(8) & P_{3}(8) & P_{4}(8) \\ P_{1}(9) & P_{2}(9) & P_{3}(9) & P_{4}(9) \end{bmatrix} = \begin{bmatrix} 0.001 & 0.0057 & 0 & 0.2177 \\ 0.061 & 0.0027 & 0.2183 & 0 \\ 0.0002 & 0 & 0.4785 & 0.0111 \\ 0.1892 & 0.0016 & 0.0241 & 0.6201 \\ 0.0243 & 0 & 0 & 0.08 \\ 0 & 0.0001 & 0 & 0 \\ 0.0015 & 0 & 0.0007 & 0.0002 \\ 0.0021 & 0 & 0 & 0.0158 \\ 0.0005 & 0.001 & 0.2778 & 0 \\ 0.7196 & 0.9883 & 0.0002 & 0.0546 \end{bmatrix}$$
From $P$, the largest probabilities in $P_1, P_2, P_3, P_4$ are $0.7196$, $0.9883$, $0.4785$, and $0.6201$, located at indices $9$, $9$, $2$, and $3$. The corresponding one-hot vectors are $[0,0,0,0,0,0,0,0,0,1]^T$, $[0,0,0,0,0,0,0,0,0,1]^T$, $[0,0,1,0,0,0,0,0,0,0]^T$, and $[0,0,0,1,0,0,0,0,0,0]^T$, i.e. the words '.' (period), '.' (period), 'plays', and 'in'. The predicted context words, in order, are therefore '.' (period), '.' (period), 'plays', and 'in'. Since the weights are randomly initialized and have not been trained, these predictions are not expected to match the true context.
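Reading off the predictions amounts to an argmax over each column of $P$. A sketch, with the columns of $P$ and the vocabulary copied from above; `argmax` is an illustrative helper name.

```python
vocab = ["the", "cat", "plays", "in", "garden", ",", "and", "chases", "mouse", "."]

# Columns of P from the text: P_1, P_2, P_3, P_4.
P = [
    [0.001, 0.061, 0.0002, 0.1892, 0.0243, 0, 0.0015, 0.0021, 0.0005, 0.7196],
    [0.0057, 0.0027, 0, 0.0016, 0, 0.0001, 0, 0, 0.001, 0.9883],
    [0, 0.2183, 0.4785, 0.0241, 0, 0, 0.0007, 0, 0.2778, 0.0002],
    [0.2177, 0, 0.0111, 0.6201, 0.08, 0, 0.0002, 0.0158, 0, 0.0546],
]


def argmax(values):
    """Index of the largest element."""
    return max(range(len(values)), key=lambda k: values[k])


indices = [argmax(p_j) for p_j in P]
one_hots = [[1 if k == idx else 0 for k in range(len(vocab))] for idx in indices]
predicted = [vocab[idx] for idx in indices]
print(indices)      # [9, 9, 2, 3]
print(predicted)    # ['.', '.', 'plays', 'in']
```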
Next, we continue with the example above to explain how the loss function is computed.
Loss function
The Skip-gram model uses the same loss function as the CBOW model: the CrossEntropyError module in the model structure computes the cross-entropy loss, whose formula is given below. The inputs to CrossEntropyError are the probability vectors $P_j$ computed by the Softmax layer, where $P = (P_1, P_2, \dots, P_{2C})$, and the correct supervision labels $T_j = (t_j(0), t_j(1), \dots, t_j(V-1))^T$, each of which is the one-hot vector of the correct answer word at position $j$. The loss of each predicted word is computed and then accumulated to obtain the total loss:
$$L_j = - \sum\limits_{k = 0}^{V-1} t_j(k)\log(P_j(k))$$
$$\text{Loss} = \sum\limits_{j = 1}^{2C} L_j$$
We compute the cross-entropy loss of each position with the first formula, then add up the $2 \times C$ cross-entropy losses to obtain the total loss.
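The text does not list the supervision labels for the example, so the following is an illustration under assumptions: in the corpus, the two words before 'plays' are 'the' and 'cat' (indices 0 and 1) and the two words after it are 'in' and 'the' (indices 3 and 0), and slot $j = 1$ is taken to be the farthest preceding word. With one-hot labels, each $L_j$ reduces to $-\log P_j(k^*)$, where $k^*$ is the index of the correct word. The probabilities are the rounded values of $P$ from the example, and the small `eps` guard against $\log 0$ is an addition of this sketch.

```python
import math

# Probabilities P_j(k) copied from the example (rounded), one list per context slot.
P = [
    [0.001, 0.061, 0.0002, 0.1892, 0.0243, 0, 0.0015, 0.0021, 0.0005, 0.7196],
    [0.0057, 0.0027, 0, 0.0016, 0, 0.0001, 0, 0, 0.001, 0.9883],
    [0, 0.2183, 0.4785, 0.0241, 0, 0, 0.0007, 0, 0.2778, 0.0002],
    [0.2177, 0, 0.0111, 0.6201, 0.08, 0, 0.0002, 0.0158, 0, 0.0546],
]
targets = [0, 1, 3, 0]   # assumed true context of "plays": the, cat, in, the


def cross_entropy(probs, target_index, eps=1e-12):
    """-sum_k t(k) log P(k) for a one-hot t; eps guards against log(0)."""
    return -math.log(probs[target_index] + eps)


losses = [cross_entropy(p_j, t_j) for p_j, t_j in zip(P, targets)]   # L_1..L_4
total_loss = sum(losses)                                             # Loss
print([round(l, 4) for l in losses])
print(round(total_loss, 4))
```

Under these assumptions the total comes out at roughly 18, dominated by the slots where the untrained model assigns almost no probability to the correct word.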
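Summary
The basic steps of training the Skip-gram model are:
1. Represent the target word as a one-hot vector and use it as the model input $x_i$, where the vocabulary size is $V$ and the context window size is $C$ (so $2C$ context words are predicted);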
2. Multiply the one-hot vector of the target word by the input-to-hidden weight matrix $W$ and add the bias $b$ to obtain the hidden-layer vector $h$, where the hidden layer has dimension $D$. The formula is
$$h = Wx_i + b$$
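3. Multiply the hidden-layer vector $h$ by the hidden-to-output weight matrix $W_j'$ and add the bias $b_j'$ to obtain the score vectors $S_j$ of the predicted context words, i.e.
$$S_j = W_j'h + b_j'$$
There are $2 \times C$ pairs of output weight matrices and biases in total, and $j$ is the index of the context word to be predicted.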
4. Pass each score vector $S_j$ through the Softmax activation to obtain a $V$-dimensional probability distribution $P_j$, i.e.
$$P_j = \text{Softmax}(S_j)$$
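5. Take the index with the largest probability as the predicted context word, and compute the cross-entropy loss from the probability distribution $P_j$ and the one-hot supervision label $T_j$.
Our goal is to reduce the loss by gradient descent so that the model learns to infer the most likely context words from the target word. Once training is finished, $W$ or $W'$, a by-product of training, is our word-vector matrix.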
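For reference, with the Softmax and cross-entropy loss defined in this article, the gradients needed by a plain gradient-descent update take the following standard form. The learning rate $\eta$ and the shorthand $\theta$ for any parameter are symbols introduced here for illustration; the text does not specify an optimizer or its settings.

$$
\begin{aligned}
\frac{\partial \text{Loss}}{\partial S_j} &= P_j - T_j \\
\frac{\partial \text{Loss}}{\partial W_j'} &= (P_j - T_j)\,h^T, \qquad \frac{\partial \text{Loss}}{\partial b_j'} = P_j - T_j \\
\frac{\partial \text{Loss}}{\partial h} &= \sum\limits_{j=1}^{2C} (W_j')^T (P_j - T_j) \\
\frac{\partial \text{Loss}}{\partial W} &= \frac{\partial \text{Loss}}{\partial h}\,x_i^T, \qquad \frac{\partial \text{Loss}}{\partial b} = \frac{\partial \text{Loss}}{\partial h} \\
\theta &\leftarrow \theta - \eta\,\frac{\partial \text{Loss}}{\partial \theta}, \qquad \theta \in \{W, b, W_j', b_j'\}
\end{aligned}
$$

Because $x_i$ is one-hot, $\frac{\partial \text{Loss}}{\partial W}$ is nonzero only in column $i$, so each training example updates only the input vector of its target word.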