Backpropagation (computing partial derivatives):
$f(x, y, z) = (x + y)z$
e.g.: $x = -2,\ y = 5,\ z = -4$
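Evaluating f on these inputs is the forward pass. A minimal Python sketch, storing the intermediate sum x + y as q (the same name used when f is decomposed into simpler nodes):

```python
# Inputs from the example
x, y, z = -2.0, 5.0, -4.0

# Forward pass: evaluate the intermediate node q, then the output f
q = x + y  # 3.0
f = q * z  # -12.0

print(q, f)  # 3.0 -12.0
```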
$q = x + y,\quad \frac{\partial q}{\partial x} = 1,\quad \frac{\partial q}{\partial y} = 1$
$f = qz,\quad \frac{\partial f}{\partial q} = z,\quad \frac{\partial f}{\partial z} = q$
Want: $\frac{\partial f}{\partial x},\ \frac{\partial f}{\partial y},\ \frac{\partial f}{\partial z}$
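The local derivatives of the two nodes can be sanity-checked numerically with central differences, $(g(a+h) - g(a-h)) / 2h$. This is a minimal sketch; the helper names `q_fn`, `f_fn` and `central_diff`, and the step size `h = 1e-5`, are hypothetical choices made only for this illustration:

```python
def q_fn(x, y):
    # first node: q = x + y
    return x + y


def f_fn(q, z):
    # second node: f = q * z
    return q * z


def central_diff(fn, args, i, h=1e-5):
    """Estimate d fn / d args[i] with a central difference (hypothetical helper)."""
    plus, minus = list(args), list(args)
    plus[i] += h
    minus[i] -= h
    return (fn(*plus) - fn(*minus)) / (2 * h)


x, y, z = -2.0, 5.0, -4.0
q = q_fn(x, y)

dq_dx = central_diff(q_fn, (x, y), 0)  # approx. 1
dq_dy = central_diff(q_fn, (x, y), 1)  # approx. 1
df_dq = central_diff(f_fn, (q, z), 0)  # approx. z = -4
df_dz = central_diff(f_fn, (q, z), 1)  # approx. q = 3
```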
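Backpropagation computes partial derivatives starting from the output and working backward one node at a time, so the partial derivative at each node can be obtained with the chain rule: multiply the gradient arriving from later nodes by that node's local derivative.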
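Concretely, x and y influence f only through q, so each of their gradients is the gradient already computed at q (upstream) multiplied by the local derivative of q:

```latex
\underbrace{\frac{\partial f}{\partial x}}_{\text{downstream}}
  = \underbrace{\frac{\partial f}{\partial q}}_{\text{upstream}}
    \cdot \underbrace{\frac{\partial q}{\partial x}}_{\text{local}},
\qquad
\frac{\partial f}{\partial y}
  = \frac{\partial f}{\partial q} \cdot \frac{\partial q}{\partial y}
```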
Compute $\frac{\partial f}{\partial f}$: the gradient of the output with respect to itself is $1$.
Compute $\frac{\partial f}{\partial q}$: $\frac{\partial f}{\partial q} = z = -4$.
Compute $\frac{\partial f}{\partial z}$: $\frac{\partial f}{\partial z} = q = x + y = 3$.
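Compute $\frac{\partial f}{\partial x}$: by the chain rule, $\frac{\partial f}{\partial x} = \frac{\partial f}{\partial q} \cdot \frac{\partial q}{\partial x} = -4 \cdot 1 = -4$.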
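Compute $\frac{\partial f}{\partial y}$: by the chain rule, $\frac{\partial f}{\partial y} = \frac{\partial f}{\partial q} \cdot \frac{\partial q}{\partial y} = -4 \cdot 1 = -4$.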
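The backward steps for this example can be reproduced with a few lines of Python. This is a minimal hand-written sketch for this single expression, not a general implementation; the variable names (`df_dq` and so on) are chosen here for illustration:

```python
x, y, z = -2.0, 5.0, -4.0

# Forward pass: keep the intermediate value q for the backward pass
q = x + y  # 3.0
f = q * z  # -12.0

# Backward pass: start at the output and apply the chain rule node by node
df_df = 1.0          # gradient of f with respect to itself
df_dq = z * df_df    # local d(qz)/dq = z          -> -4.0
df_dz = q * df_df    # local d(qz)/dz = q          -> 3.0
df_dx = 1.0 * df_dq  # local dq/dx = 1, chain rule -> -4.0
df_dy = 1.0 * df_dq  # local dq/dy = 1, chain rule -> -4.0

print(df_dx, df_dy, df_dz)  # -4.0 -4.0 3.0
```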
Computational graph
Pseudocode:
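```python
class ComputationGraph(object):
    # Forward pass: compute the value of every node in the computational graph
    def forward(self, inputs):
        # 1. [feed the input data into the input nodes]
        # 2. [propagate forward to compute the value of every node in the graph]
        for gate in self.graph.nodes_topologically_sorted():
            gate.forward()
        return loss  # final output: the loss

    # Backward pass: compute the gradient of the final output with respect to every variable
    def backward(self):
        for gate in reversed(self.graph.nodes_topologically_sorted()):
            gate.backward()
        return inputs_gradients  # final output: the gradients with respect to the inputs
```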
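The pseudocode leaves the graph structure abstract. Below is a minimal runnable sketch, assuming each gate is stored together with the names of its input and output nodes, and that the list of gates is already topologically sorted. The generic `Gate` class and its `fn` / `local_grads` fields are hypothetical, introduced only for this example:

```python
class Gate:
    """Hypothetical generic gate: a forward function plus its local gradients."""

    def __init__(self, fn, local_grads, ins, out):
        self.fn = fn                    # forward function of the inputs
        self.local_grads = local_grads  # returns d(out)/d(input) for each input
        self.ins = ins                  # names of the input nodes
        self.out = out                  # name of the output node


class ComputationGraph(object):
    def __init__(self, gates):
        # Assumption: gates are already listed in topological order
        self.gates = gates

    def forward(self, inputs):
        # 1. feed the input data into the graph
        self.values = dict(inputs)
        # 2. evaluate every node in topological order
        for g in self.gates:
            self.values[g.out] = g.fn(*(self.values[n] for n in g.ins))
        return self.values[self.gates[-1].out]  # value of the final node

    def backward(self):
        final = self.gates[-1].out
        grads = {final: 1.0}  # d(final)/d(final) = 1
        for g in reversed(self.gates):
            upstream = grads.get(g.out, 0.0)
            args = [self.values[n] for n in g.ins]
            for name, local in zip(g.ins, g.local_grads(*args)):
                # chain rule: upstream * local; summed if a node feeds several gates
                grads[name] = grads.get(name, 0.0) + upstream * local
        return grads


graph = ComputationGraph([
    Gate(lambda a, b: a + b, lambda a, b: (1.0, 1.0), ("x", "y"), "q"),
    Gate(lambda a, b: a * b, lambda a, b: (b, a), ("q", "z"), "f"),
])
out = graph.forward({"x": -2.0, "y": 5.0, "z": -4.0})
grads = graph.backward()
print(out)    # -12.0
print(grads)  # {'f': 1.0, 'q': -4.0, 'z': 3.0, 'x': -4.0, 'y': -4.0}
```

Summing into `grads` instead of overwriting matters when a node feeds more than one gate: the multivariable chain rule adds the contributions from every path.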
For example, a gate that multiplies its two inputs:
```python
class MultiplyGate(object):
    def forward(self, x, y):
        z = x * y
        # cache the inputs: backward() needs them to compute the local gradients
        self.x = x
        self.y = y
        return z

    def backward(self, dz):
        dx = dz * self.y  # [dL/dz * dz/dx]
        dy = dz * self.x  # [dL/dz * dz/dy]
        return [dx, dy]
```
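Pairing MultiplyGate with an add gate is enough to run the whole example $f = (x + y)z$: forward through the gates in order, then backward in reverse. `AddGate` is a hypothetical companion written for this sketch, and `MultiplyGate` is repeated so the block runs on its own:

```python
class MultiplyGate(object):
    def forward(self, x, y):
        self.x, self.y = x, y  # cache the inputs: backward() needs them
        return x * y

    def backward(self, dz):
        dx = dz * self.y  # dL/dz * dz/dx
        dy = dz * self.x  # dL/dz * dz/dy
        return [dx, dy]


class AddGate(object):
    """Hypothetical companion gate for q = x + y (not in the original text)."""

    def forward(self, x, y):
        return x + y

    def backward(self, dz):
        # d(x+y)/dx = d(x+y)/dy = 1: the upstream gradient passes through unchanged
        return [dz, dz]


add, mul = AddGate(), MultiplyGate()

# Forward pass in topological order: q = x + y, then f = q * z
q = add.forward(-2.0, 5.0)  # 3.0
f = mul.forward(q, -4.0)    # -12.0

# Backward pass in reverse order, starting from df/df = 1
dq, dz = mul.backward(1.0)  # -4.0, 3.0
dx, dy = add.backward(dq)   # -4.0, -4.0
```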