Prerequisites
Matrix calculus:
https://blog.csdn.net/qq_39942341/article/details/128739604?spm=1001.2014.3001.5502
(the section on differentials is enough)
Regression
Let
$$\mathbf{X}\in \mathbb{R}^{B\times m},\quad \mathbf{W}_1\in \mathbb{R}^{n\times m},\quad \mathbf{1}\in \mathbb{R}^{B\times 1},\quad \mathbf{b}_1\in\mathbb{R}^{1\times n},\quad \mathbf{Y}_1\in\mathbb{R}^{B\times n}$$
(The all-ones vector $\mathbf{1}$ must be $B\times 1$ so that $\mathbf{1}\mathbf{b}_1\in\mathbb{R}^{B\times n}$.)
$$\mathbf{W}_2\in \mathbb{R}^{p\times n},\quad \mathbf{b}_2\in\mathbb{R}^{1\times p},\quad \mathbf{Y}_2\in\mathbb{R}^{B\times p}$$
$\sigma\left(\cdot\right)$ is the activation function, e.g. sigmoid.
The forward pass and loss are
$$\begin{aligned} \mathbf{Y}_1 &= \mathbf{X}\mathbf{W}_1^T + \mathbf{1}\mathbf{b}_1\\ \mathbf{A}_1 &= \sigma\left(\mathbf{Y}_1\right)\\ \mathbf{Y}_2 &= \mathbf{A}_1\mathbf{W}_2^T +\mathbf{1}\mathbf{b}_2\\ \mathbf{A}_2 &= \sigma\left(\mathbf{Y}_2\right)\\ l &= \frac{1}{2}\,\mathrm{mse}\left(\mathbf{A},\mathbf{A}_2\right) = \frac{1}{2}\|\mathbf{A}-\mathbf{A}_2\|_F^2 \end{aligned}$$
where $\mathbf{A}\in\mathbb{R}^{B\times p}$ is the regression target.
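As a sanity check on these shape conventions, note that `nn.Linear(m, n)` in PyTorch stores its weight as an $n\times m$ matrix, matching $\mathbf{W}_1\in\mathbb{R}^{n\times m}$ above. A minimal sketch (variable names here are illustrative):

import torch
from torch import nn

B, m, n = 3, 5, 4
linear = nn.Linear(m, n)      # weight: (n, m), bias: (n,)
X = torch.randn(B, m)
ones = torch.ones(B, 1)       # the all-ones vector 1 in R^{B x 1}
# X W^T + 1 b, exactly as in the equations above
manual = X @ linear.weight.T + ones @ linear.bias.reshape(1, n)
print(torch.allclose(manual, linear(X)))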
The gradient with respect to the output activation is
$$\frac{\partial l}{\partial \mathbf{A}_2} = \mathbf{A}_2 - \mathbf{A}$$
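This identity is easy to confirm with autograd; a minimal sketch (all names are illustrative, not from the original):

import torch

B, p = 3, 6
A = torch.randn(B, p)                      # target
A2 = torch.randn(B, p, requires_grad=True)
l = 0.5 * torch.sum((A - A2) ** 2)         # 1/2 ||A - A2||_F^2
l.backward()
print(torch.allclose(A2.grad, A2 - A))     # dl/dA2 = A2 - A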
Taking differentials and using $\mathrm{d}\sigma\left(\mathbf{Y}_2\right)=\sigma^\prime\left(\mathbf{Y}_2\right)\odot\mathrm{d}\mathbf{Y}_2$:
$$\begin{aligned} \mathrm{d}l &= \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T \mathrm{d}\mathbf{A}_2\right)\\ &=\mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T \mathrm{d}\sigma\left(\mathbf{Y}_2\right)\right)\\ &=\mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T\left(\sigma^\prime\left(\mathbf{Y}_2\right)\odot\mathrm{d}\mathbf{Y}_2\right)\right) \\ &= \mathrm{tr}\left(\left(\frac{\partial l}{\partial \mathbf{A}_2}\odot\sigma^\prime\left(\mathbf{Y}_2\right) \right)^T\mathrm{d}\mathbf{Y}_2\right) \\ &= \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\mathbf{Y}_2\right) \end{aligned}$$
Therefore
$$\frac{\partial l}{\partial \mathbf{Y}_2} = \frac{\partial l}{\partial \mathbf{A}_2}\odot\sigma^\prime\left(\mathbf{Y}_2\right)$$
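Again this can be confirmed against autograd; a minimal sketch using sigmoid as $\sigma$ (illustrative names):

import torch

B, p = 3, 6
A = torch.randn(B, p)
Y2 = torch.randn(B, p, requires_grad=True)
A2 = torch.sigmoid(Y2)
l = 0.5 * torch.sum((A - A2) ** 2)
l.backward()
grad_Y2 = (A2 - A) * A2 * (1 - A2)   # dl/dA2 ⊙ σ'(Y2), with σ' = σ(1-σ)
print(torch.allclose(Y2.grad, grad_Y2))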
Substituting $\mathbf{Y}_2 = \mathbf{A}_1\mathbf{W}_2^T + \mathbf{1}\mathbf{b}_2$ and using the cyclic property of the trace:
$$\begin{aligned} \mathrm{d}l &= \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\mathbf{Y}_2\right)\\ &= \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\left(\mathbf{A}_1\mathbf{W}_2^T +\mathbf{1}\mathbf{b}_2\right)\right)\\ &= \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\left(\mathrm{d}\mathbf{A}_1\right)\mathbf{W}_2^T\right) + \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1\left(\mathrm{d}\mathbf{W}_2^T\right)\right) + \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{1}\,\mathrm{d}\mathbf{b}_2\right)\\ &= \mathrm{tr}\left(\mathbf{W}_2^T\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\mathbf{A}_1\right) + \mathrm{tr}\left(\left(\mathrm{d}\mathbf{W}_2^T\right)\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1\right) + \mathrm{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{1}\,\mathrm{d}\mathbf{b}_2\right) \end{aligned}$$
Therefore
$$\frac{\partial l}{\partial \mathbf{A}_1} = \frac{\partial l}{\partial \mathbf{Y}_2}\mathbf{W}_2,\qquad \frac{\partial l}{\partial \mathbf{W}_2} = \frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1,\qquad \frac{\partial l}{\partial \mathbf{b}_2} =\mathbf{1}^T\frac{\partial l}{\partial \mathbf{Y}_2}$$
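These three identities can be checked in isolation by putting an arbitrary constant $\mathbf{G}$ in the role of $\partial l/\partial\mathbf{Y}_2$, i.e. differentiating $l=\mathrm{tr}\left(\mathbf{G}^T\mathbf{Y}_2\right)$. A minimal sketch (illustrative names):

import torch

B, n, p = 3, 4, 6
G = torch.randn(B, p)                        # plays the role of dl/dY2
A1 = torch.randn(B, n, requires_grad=True)
W2 = torch.randn(p, n, requires_grad=True)
b2 = torch.randn(1, p, requires_grad=True)
Y2 = A1 @ W2.T + torch.ones(B, 1) @ b2
l = torch.sum(G * Y2)                        # equals tr(G^T Y2)
l.backward()
print(torch.allclose(A1.grad, G @ W2))                  # dl/dA1 = G W2
print(torch.allclose(W2.grad, G.T @ A1))                # dl/dW2 = G^T A1
print(torch.allclose(b2.grad, torch.ones(B, 1).T @ G))  # dl/db2 = 1^T G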
Similarly,
$$\frac{\partial l}{\partial \mathbf{Y}_1} = \frac{\partial l}{\partial \mathbf{A}_1}\odot\sigma^\prime\left(\mathbf{Y}_1\right),\qquad \frac{\partial l}{\partial \mathbf{W}_1} = \frac{\partial l}{\partial \mathbf{Y}_1}^T\mathbf{X},\qquad \frac{\partial l}{\partial \mathbf{b}_1} =\mathbf{1}^T\frac{\partial l}{\partial \mathbf{Y}_1}$$
If sigmoid is used, then
$$\sigma^{\prime}\left(\mathbf{X}\right) =\sigma\left(\mathbf{X}\right)\odot\left(1-\sigma\left(\mathbf{X}\right)\right)$$
If ReLU is used, then
$$\left[\sigma^{\prime}\left(\mathbf{X}\right)\right]_{ij} =\begin{cases} 1, & X_{ij}>0\\ 0, & \text{otherwise} \end{cases}$$
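Both derivative formulas can be checked against autograd; a minimal sketch (illustrative names):

import torch

X = torch.randn(4, 5, requires_grad=True)
s = torch.sigmoid(X)
s.sum().backward()
print(torch.allclose(X.grad, s * (1 - s)))      # σ'(X) = σ(X) ⊙ (1 - σ(X))

X.grad = None                                   # reset before the second check
r = torch.relu(X)
r.sum().backward()
print(torch.allclose(X.grad, (X > 0).float()))  # relu'(X) = 1 where X > 0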
Code verification
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn


def sigmoid_derivative(A):
    # σ'(Y) expressed via the activation output A = σ(Y)
    return A * (1 - A)


def relu_derivative(A):
    # relu'(Y): 1 where the activation is positive, 0 elsewhere
    # (A = relu(Y) > 0 exactly where Y > 0)
    return torch.where(A > 0, 1.0, 0.0)


if __name__ == '__main__':
    B, m, n, p = 3, 5, 4, 6
    linear1 = nn.Linear(m, n)
    active1 = nn.Sigmoid()
    derivative_1 = sigmoid_derivative
    linear2 = nn.Linear(n, p)
    active2 = nn.ReLU()
    derivative_2 = relu_derivative
    A = torch.randn(B, p)  # regression target
    X = torch.randn(B, m, requires_grad=True)
    # forward pass
    Y1 = linear1(X)
    A1 = active1(Y1)
    Y2 = linear2(A1)
    A2 = active2(Y2)
    # l = 1/2 ||A - A2||_F^2
    l = torch.sum((A2 - A) ** 2) * 0.5
    l.backward()
    # manual backward pass, following the formulas above
    grad_A2 = A2 - A
    grad_Y2 = grad_A2 * derivative_2(A2)
    grad_W2 = torch.mm(grad_Y2.T, A1)
    grad_b2 = torch.mm(torch.ones(B, 1).T, grad_Y2)
    print(torch.allclose(grad_W2, linear2.weight.grad))
    print(torch.allclose(grad_b2, linear2.bias.grad))
    grad_A1 = torch.mm(grad_Y2, linear2.weight)
    grad_Y1 = grad_A1 * derivative_1(A1)
    grad_W1 = torch.mm(grad_Y1.T, X)
    grad_b1 = torch.mm(torch.ones(B, 1).T, grad_Y1)
    print(torch.allclose(grad_W1, linear1.weight.grad))
    print(torch.allclose(grad_b1, linear1.bias.grad))
Classification
For a row vector $\mathbf{a}\in\mathbb{R}^{1\times n}$, define
$$\mathrm{softmax}\left(\mathbf{a}\right) = \frac{e^{\mathbf{a}}}{e^{\mathbf{a}}\mathbf{1}_n}$$
where $\mathbf{1}_n\in\mathbb{R}^{n}$ is the all-ones column vector and $e^{\mathbf{a}}$ is taken elementwise.
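A quick check of this expression against the library implementation (note that `F.softmax` is the numerically stable version; the raw formula can overflow for large entries):

import torch
import torch.nn.functional as F

n = 5
a = torch.randn(1, n)
manual = torch.exp(a) / torch.mm(torch.exp(a), torch.ones(n, 1))  # e^a / (e^a 1_n)
print(torch.allclose(manual, F.softmax(a, dim=1)))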
Let $\mathbf{y}\in\mathbb{R}^{1\times n}$ be a one-hot label: exactly one element is 1 and the rest are 0.
The cross-entropy is
$$\begin{aligned} \mathrm{ce}\left(\mathbf{a},\mathbf{y}\right) &= -\log\left(\mathrm{softmax}\left(\mathbf{a}\right)\right)\mathbf{y}^T\\ &= -\left(\mathbf{a}-\log \left(e^{\mathbf{a}}\mathbf{1}_n\right)\mathbf{1}_n^T\right)\mathbf{y}^T\\ &= -\mathbf{a}\mathbf{y}^T+\log\left(e^{\mathbf{a}}\mathbf{1}_n\right) \end{aligned}$$
where the last step uses $\mathbf{1}_n^T\mathbf{y}^T = 1$.
Differentiating:
$$\begin{aligned} \mathrm{d}l &= \mathrm{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(e^{\mathbf{a}}\odot\mathrm{d}\mathbf{a}\right)\mathbf{1}_n\right)\\ &= \mathrm{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(\mathbf{1}_n^T\right)^T\left(e^{\mathbf{a}}\odot\mathrm{d}\mathbf{a}\right)\right)\\ &= \mathrm{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(\mathbf{1}_n^T\odot e^{\mathbf{a}}\right)^T\left(\mathrm{d}\mathbf{a}\right)\right)\\ &= \mathrm{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(e^{\mathbf{a}}\right)^T\left(\mathrm{d}\mathbf{a}\right)\right)\\ &= \mathrm{tr}\left(-\mathbf{y}^T\left(\mathrm{d}\mathbf{a}\right) + \left(\mathrm{softmax}\left(\mathbf{a}\right)\right)^T\left(\mathrm{d}\mathbf{a}\right)\right) \end{aligned}$$
Thus
$$\frac{\partial l}{\partial \mathbf{a}} = \mathrm{softmax}\left(\mathbf{a}\right)-\mathbf{y}$$
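A minimal single-row check of both the loss expression and this gradient (the class index 2 is arbitrary; names are illustrative):

import torch
import torch.nn.functional as F

n = 5
a = torch.randn(1, n, requires_grad=True)
y = F.one_hot(torch.tensor([2]), num_classes=n).float()   # arbitrary one-hot label
# ce(a, y) = -a y^T + log(e^a 1_n)
l = (-torch.mm(a, y.T) + torch.log(torch.mm(torch.exp(a), torch.ones(n, 1)))).sum()
l.backward()
print(torch.allclose(l, F.cross_entropy(a, torch.tensor([2]))))
print(torch.allclose(a.grad, F.softmax(a, dim=1) - y))    # dl/da = softmax(a) - y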
Now let $\mathbf{A}\in\mathbb{R}^{B\times n},\ \mathbf{Y}\in\mathbb{R}^{B\times n}$,
where each row of $\mathbf{Y}$ has exactly one element equal to 1 and the rest 0. Let $\mathbf{a}_i$ and $\mathbf{y}_i$ denote the $i$-th rows of $\mathbf{A}$ and $\mathbf{Y}$ respectively. Then
$$\mathrm{softmax}\left(\mathbf{A}\right) = \begin{pmatrix} \mathrm{softmax}\left(\mathbf{a}_1\right)\\ \mathrm{softmax}\left(\mathbf{a}_2\right)\\ \vdots\\ \mathrm{softmax}\left(\mathbf{a}_B\right) \end{pmatrix}$$
Let $\mathbf{1}_{B}\in\mathbb{R}^B$ be the all-ones vector. Then
$$\mathrm{ce}\left(\mathbf{A},\mathbf{Y}\right) = \sum_{i=1}^{B}\mathrm{ce}\left(\mathbf{a}_i,\mathbf{y}_i\right) = \mathbf{1}_B^T\log\left(e^{\mathbf{A}}\mathbf{1}_n\right)-\mathrm{tr}\left(\mathbf{A}\mathbf{Y}^T\right)$$
Differentiating row by row gives
$$\frac{\partial l}{\partial \mathbf{A}} = \begin{pmatrix} \frac{\partial l}{\partial \mathbf{a}_1}\\ \frac{\partial l}{\partial \mathbf{a}_2}\\ \vdots\\ \frac{\partial l}{\partial \mathbf{a}_B} \end{pmatrix} = \begin{pmatrix} \mathrm{softmax}\left(\mathbf{a}_1\right) - \mathbf{y}_1\\ \mathrm{softmax}\left(\mathbf{a}_2\right) - \mathbf{y}_2\\ \vdots\\ \mathrm{softmax}\left(\mathbf{a}_B\right) - \mathbf{y}_B \end{pmatrix}=\mathrm{softmax}\left(\mathbf{A}\right)-\mathbf{Y}$$
Verification:
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn
import torch.nn.functional as F

if __name__ == '__main__':
    B, n = 3, 4
    ce = nn.CrossEntropyLoss(reduction='sum')
    target = torch.empty(B, dtype=torch.long).random_(n)
    target_one_hot = F.one_hot(target, num_classes=n)
    A = torch.randn(B, n, requires_grad=True)
    l = ce(A, target)
    l.backward()
    ones_B = torch.ones(B, 1)
    ones_n = torch.ones(n, 1)
    # ce(A, Y) = 1_B^T log(e^A 1_n) - tr(A Y^T)
    output = torch.mm(ones_B.T, torch.log(torch.mm(torch.exp(A), ones_n))) - (
        torch.mm(A, target_one_hot.T.float())).trace()
    print(torch.allclose(output, l))
    # dl/dA = softmax(A) - Y
    grad_A = F.softmax(A, dim=1) - target_one_hot
    print(torch.allclose(grad_A, A.grad))