提出问题
有如下1000个样本和标签:
样本序号 | 1 | 2 | 3 | … | 1000 |
---|---|---|---|---|---|
x1 | 0.0091867 | 0.10245588 | -0.41033773 | … | -0.20625644 |
x2 | 0.00666677 | 0.20947882 | 0.18172314 | … | 0.19683694 |
y | 1 | 2 | 3 | … | 2 |
还好这个数据只有两个特征,所以我们可以用可视化的方法展示,如下图:
定义神经网络结构
- 输入层两个特征值x1, x2
- 隐层8x2的权重矩阵和8x1的偏移矩阵
- 隐层由8个神经元构成
- 输出层有3个神经元负责3分类,使用Softmax函数进行分类
前向计算
单样本矩阵运算过程:
W
1
(
8
×
2
)
⋅
X
(
2
×
1
)
+
B
1
(
8
×
1
)
=
>
Z
1
(
8
×
1
)
W_1^{(8 \times 2)} \cdot X^{(2 \times 1)} + B_1^{(8 \times 1)} => Z_1^{(8 \times 1)}
W1(8×2)⋅X(2×1)+B1(8×1)=>Z1(8×1)
S
i
g
m
o
i
d
(
Z
1
)
=
>
A
1
(
8
×
1
)
Sigmoid(Z1) => A_1^{(8 \times 1)}
Sigmoid(Z1)=>A1(8×1)
W
2
(
3
×
8
)
×
A
1
(
8
×
1
)
+
B
2
(
3
×
1
)
=
>
Z
2
(
3
×
1
)
W_2^{(3 \times 8)} \times A_1^{(8 \times 1)} + B_2^{(3 \times 1)} => Z_2^{(3 \times 1)}
W2(3×8)×A1(8×1)+B2(3×1)=>Z2(3×1)
S
o
f
t
m
a
x
(
Z
2
)
=
>
A
2
(
3
×
1
)
Softmax(Z2) => A_2^{(3 \times 1)}
Softmax(Z2)=>A2(3×1)
损失函数
使用多分类交叉熵损失函数:
J ( w , b ) = − 1 m ∑ i = 1 m ∑ j = 1 n y i j ln ( a i j ) J(w,b) = -{1 \over m} \sum^m_{i=1} \sum^n_{j=1} y_{ij} \ln (a_{ij}) J(w,b)=−m1i=1∑mj=1∑nyijln(aij)
m为样本数,n为类别数。
可以简写为:
J = − Y ln A J = -Y \ln A J=−YlnA
反向传播
∂
J
∂
A
2
∂
A
2
∂
Z
2
=
A
2
−
Y
=
>
d
Z
2
\frac{\partial{J}}{\partial{A2}} \frac{\partial{A2}}{\partial{Z2}} = A2-Y => dZ2
∂A2∂J∂Z2∂A2=A2−Y=>dZ2
虽然这个求导结果和二分类一样,但是过程截然不同,详情请看6.4。
后续的梯度求解与9.1节一样,只拷贝结论在这里:
(2) d W 2 = d Z 2 × A 1 T dW2=dZ2 \times A1^T \tag{2} dW2=dZ2×A1T(2)
(3) d B 2 = d Z 2 dB2=dZ2 \tag{3} dB2=dZ2(3)
(4) W 2 T × d Z 2 ⊙ A 1 ⊙ ( 1 − A 1 ) = > d Z 1 W2^T \times dZ2 \odot A1 \odot (1-A1) => dZ1 \tag{4} W2T×dZ2⊙A1⊙(1−A1)=>dZ1(4)
(5) d W 1 = d Z 1 ⋅ X T dW1= dZ1 \cdot X^T \tag{5} dW1=dZ1⋅XT(5)
(6)
d
B
1
=
d
Z
1
dB1= dZ1 \tag{6}
dB1=dZ1(6)
迭代了10000次,没有到底损失函数小于0.06的条件。
分类结果图示:
多分类的工作原理
使用以下参数测试:
- eta = 0.1
- batch_size = 10
- n_hidden = 3
- eps = 0.005
如果隐层只使用2个神经元,只能得到近似的线性结果,如下图:
所以,隐层必须用3个神经元以上。以下是结果:
多分类损失函数值 | 分类结果(待优化) |
---|---|
https://github.com/microsoft/ai-edu/blob/master/B-教学案例与实践/B6-神经网络基本原理简明教程/11.2-理解多分类的工作原理.md
https://github.com/microsoft/ai-edu/blob/master/B-教学案例与实践/B6-神经网络基本原理简明教程/11.1-非线性多分类.md