百度飞桨(PaddlePaddle)如何使用mish激活函数
在YOLO-V4中,作者使用了一个新的深度学习激活函数 Mish。
(大概长成图上这个样子)
大家看到这个图像,是不是感觉很熟悉呢?这个函数图像和我们熟悉的ReLU长得很像,但是和ReLU相比,该函数保留了负数,而不是将负值全部丢弃。
Mish的梯度要比ReLU更加平滑,该函数在最终准确度上比Swish(+0.494%)和ReLU(+1.671%)都有提高。
然而这么好的激活函数,在百度的API文档里翻不出来…
既然没有,那就只能咱们自己动手了。首先来看Mish的函数表达式:
f ( x ) = x ∗ t a n h ( l n ( 1 + e x ) ) f(x) = x * tanh(ln(1+e^x)) f(x)=x∗tanh(ln(1+ex))
大家一看到数学表达式就容易晕。其实不用怕,虽然看不懂,但是不妨碍我们写代码。
首先最里面是1加上e的x次方,查阅API文档后可得到代码:
1 + fluid.layers.exp(x)
然后在这个外边套上ln,众所周知ln是log以e为底的简写:
fluid.layers.log(1 + fluid.layers.exp(x))
最后再套上tanh,前面再乘x,代码就完成了!是不是很easy
mish = x * fluid.layers.tanh(fluid.layers.log(1 + fluid.layers.exp(x)))
咱们写的这个对不对呢?咱们写个代码把图画出来看一下
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
def Mish(x):
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x)
mish = x * fluid.layers.tanh(fluid.layers.log(1 + fluid.layers.exp(x)))
return mish
if __name__ == '__main__':
in1 = np.array([-5 + i * 0.1 for i in range(100)])
out1 = Mish(in1)
plt.plot(in1, out1.numpy(), color='red')
plt.xticks([-5 + i for i in range(10)])
plt.yticks([-1 + i for i in range(6)])
# 获取当前的坐标轴, gca = get current axis
ax = plt.gca()
# 设置标题,也可用plt.title()设置
ax.set_title('Mish', fontsize=20, loc='left')
# 设置右边框和上边框,隐藏
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
# 设置x坐标轴为下边框
ax.xaxis.set_ticks_position('bottom')
# 设置y坐标轴为左边框
ax.yaxis.set_ticks_position('left')
# 设置x轴, y轴在(0, 0)的位置
ax.spines['bottom'].set_position(('data', 0))
ax.spines['left'].set_position(('data', 0))
plt.show()