吴恩达教授在Coursera课程Deeplearning关于矩阵维度总结
在Deeplearning这门课程第一周的第三节课:Shallow Neural Networks上,以两层神经网络分析 forward propagation和back propagation各计算公式的矩阵维度。
根据如下的神经网络架构:
1. Forward propagation
1.1 公式1
公式 |
Z[1]
Z
[
1
]
|
W[1]
W
[
1
]
|
X
X
|
b[1]
b
[
1
]
|
---|---|---|---|---|
维度 |
(n[1]×m)
(
n
[
1
]
×
m
)
|
(n[1]×n[0])
(
n
[
1
]
×
n
[
0
]
)
|
(n[0]×m)
(
n
[
0
]
×
m
)
|
(n[1]×m)
(
n
[
1
]
×
m
)
|
备注:
- m m 是样本数
1.2 公式2
公式 |
A[1]
A
[
1
]
|
Z[1]
Z
[
1
]
|
---|---|---|
维度 |
(n[1]×m)
(
n
[
1
]
×
m
)
|
(n[1]×m)
(
n
[
1
]
×
m
)
|
备注:
- 符号 ∗ ∗ 是矩阵点乘操作
- g[1] g [ 1 ] 可以是 sigmoid、tanh、或者Relu s i g m o i d 、 t a n h 、 或 者 R e l u 函数
1.3 公式3
公式 |
Z[2]
Z
[
2
]
|
W[2]
W
[
2
]
|
A[1]
A
[
1
]
|
b[2]
b
[
2
]
|
---|---|---|---|---|
维度 |
(n[2]×m)
(
n
[
2
]
×
m
)
|
(n[2]×n[1])
(
n
[
2
]
×
n
[
1
]
)
|
(n[1]×m)
(
n
[
1
]
×
m
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
备注:
- 在神经网络中,每一组 W W 参数都是一行,与每一列 A A 相乘;在Logistic Regression中, W W 只有一列,是列向量,所以会写成 Z=WT X+b Z = W T X + b 的形式。
1.4 公式4
公式 |
A[1]
A
[
1
]
|
Z[1]
Z
[
1
]
|
---|---|---|
维度 |
(n[2]×m)
(
n
[
2
]
×
m
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
备注:
- 符号 ∗ ∗ 是矩阵点乘操作
- g[2] g [ 2 ] 一般是 sigmoid s i g m o i d 函数(二分类)
2. Back propagation
2.1 公式5
公式 |
dZ[2]
d
Z
[
2
]
|
A[2]
A
[
2
]
|
Y
Y
|
---|---|---|---|
维度 |
(n[2]×m)
(
n
[
2
]
×
m
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
备注:
- A[2] A [ 2 ] 中的每一个元素在链式求导时都作为一个变量,所以维度是 (n[2]×m) ( n [ 2 ] × m )
- 如果是二分类,则 n[2]=1,Y=[y(1)y(2) ... y(m)] n [ 2 ] = 1 , Y = [ y ( 1 ) y ( 2 ) . . . y ( m ) ]
- 默认 g[2] g [ 2 ] 函数是 sigmoid s i g m o i d 函数
2.2 公式6
公式 |
dW[2]
d
W
[
2
]
|
dZ[2]
d
Z
[
2
]
|
A[1]T
A
[
1
]
T
|
---|---|---|---|
维度 |
(n[2]×n[1])
(
n
[
2
]
×
n
[
1
]
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
(m×n[1])
(
m
×
n
[
1
]
)
|
备注:
- W[2] W [ 2 ] 中的每个元素都是 dZ[2] d Z [ 2 ] 中的每一行与 A[1]T A [ 1 ] T 中的每一列相乘的结果。每一个乘法的是意思的链式求导,相乘之后的加法意义是在计算 W[2] W [ 2 ] 每个参数的梯度时,将m个样本的loss相加,即: ∂Lost∂Wij=1m∑loss ∂ L o s t ∂ W i j = 1 m ∑ l o s s
2.3 公式7
公式 |
db[2]
d
b
[
2
]
|
---|---|
维度 |
(n[2]×1)
(
n
[
2
]
×
1
)
|
2.4 公式8
公式 |
dZ[1]
d
Z
[
1
]
|
W[2]T
W
[
2
]
T
|
dZ[2]
d
Z
[
2
]
|
g[1]′(Z[1])
g
[
1
]
′
(
Z
[
1
]
)
|
---|---|---|---|---|
维度 |
(n[1]×m)
(
n
[
1
]
×
m
)
|
(n[1]×n[2])
(
n
[
1
]
×
n
[
2
]
)
|
(n[2]×m)
(
n
[
2
]
×
m
)
|
(n[1]×m)
(
n
[
1
]
×
m
)
|
备注:
- 符号 ∗ ∗ 是矩阵点乘操作
- g[1] g [ 1 ] 一般是 sigmoid s i g m o i d 函数(二分类)
- 当 g[1]=sigmoid,g[1]′=g(1)(1−g(1)); g [ 1 ] = s i g m o i d , g [ 1 ] ′ = g ( 1 ) ( 1 − g ( 1 ) ) ;
- 当 g[1]=tanh,g[1]′=(1−g(1)2); g [ 1 ] = t a n h , g [ 1 ] ′ = ( 1 − g ( 1 ) 2 ) ;
当 g[1]=Relu,max(0,Z), g [ 1 ] = R e l u , m a x ( 0 , Z ) ,
g[1]′={0if Z < 01if Z ⩾ 0 g [ 1 ] ′ = { 0 i f Z < 0 1 i f Z ⩾ 0当 g[1]=Relu,max(0.01Z,Z), g [ 1 ] = R e l u , m a x ( 0.01 Z , Z ) ,
g[1]′={0.01if Z < 01if Z ⩾ 0 g [ 1 ] ′ = { 0.01 i f Z < 0 1 i f Z ⩾ 0
2.5 公式9
公式 |
dW[1]
d
W
[
1
]
|
dZ[1]
d
Z
[
1
]
|
XT
X
T
|
---|---|---|---|
维度 |
(n[1]×n[0])
(
n
[
1
]
×
n
[
0
]
)
|
(n[1]×m)
(
n
[
1
]
×
m
)
|
(m×n[0])
(
m
×
n
[
0
]
)
|
2.6 公式10
公式 |
db[1]
d
b
[
1
]
|
---|---|
维度 |
(n[1]×1)
(
n
[
1
]
×
1
)
|
3. 总结
公式书写习惯学习自吴恩达的课程,有错误地方,烦请指正。