- 正向传播x>0,反向传播时,则会将上游的值原封不动的传递给下游
- 正向传播x<0,反向传播时,则会将在此刻停止
# coding: utf-8
class Relu():
"""Relu函数,反向传播时,x>0则会将上游的值原封不动的传递给下游(dx = dout)
x<0则会将信号停在这里(dout=0)
先将输入数据转换为True和False的mask数组"""
def __init__(self):
self.mask = None # mask轮廓的含义,mask是由True/Fase组成的numpy数组。
def forward(self, x):
self.mask = (x <= 0) # mask会将x元素小于等于0的地方保存为True,其他地方都保存为False
out = x.copy() # False的地方输出为x
out[self.mask] = 0 # 将True的地方输出为0
return out
def backward(self, dout):
dout[self.mask] = 0 # 前面保存了mask,True的地方反向传播会停在这个地方,故TRUE的地方设置为0,False的地方是将上游的值原封不动的传递给下游
dx = dout
return dx