一、激活函数及其梯度
一、简介
- 在神经网络中引入激活函数的目的是:我们的数据绝大部分都是非线性的,而在一般的神经网络中,计算都是线性的。而引入激活函数就是在神经网络中引入非线性,强化网络的学习能力。
二、常见的激活函数
1、Sigmoid函数
- 公式:
f
(
x
)
=
σ
(
x
)
=
1
1
+
e
−
x
f(x)=\sigma(x)=\frac{1}{1+e^{-x}}
f(x)=σ(x)=1+e−x1
- 导数:
f
(
x
)
′
=
e
−
x
(
1
+
e
−
x
)
2
=
f
(
x
)
∗
(
1
−
f
(
x
)
)
f(x)' = \frac{e^{-x}}{(1+e^{-x})^2}=f(x)*(1-f(x))
f(x)′=(1+e−x)2e−x=f(x)∗(1−f(x))
- 代码
a = torch.linspace(-100,100,10)
print(a)
print(torch.sigmoid(a))
# 输出:
# tensor([-100.0000, -77.7778, -55.5556, -33.3333, -11.1111, 11.1111,
# 33.3333, 55.5555, 77.7778, 100.0000])
# tensor([0.0000e+00, 1.6655e-34, 7.4564e-25, 3.3382e-15, 1.4945e-05, 9.9999e-01,
# 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00])
2、Tanh函数
- 公式:
f
(
x
)
=
t
a
n
h
(
x
)
=
e
x
−
e
−
x
e
x
+
e
−
x
=
2
∗
s
i
g
m
o
i
d
(
2
x
)
−
1
f(x)=tanh(x)=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}=2*sigmoid(2x)-1
f(x)=tanh(x)=ex+e−xex−e−x=2∗sigmoid(2x)−1
- 导数:
t
a
n
h
(
x
)
′
=
1
−
(
e
x
−
e
−
x
)
2
(
e
x
+
e
−
x
)
2
=
1
−
t
a
n
h
(
x
)
2
tanh(x)'=1-\frac{(e^{x}-e^{-x})^2}{(e^{x}+e^{-x})^2}=1-tanh(x)^2
tanh(x)′=1−(ex+e−x)2(ex−e−x)2=1−tanh(x)2
- 代码
a = torch.linspace(-1,1,10)
print(a)
print(torch.tanh(a))
# 输出
# tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111, 0.1111, 0.3333, 0.5556,
# 0.7778, 1.0000])
# tensor([-0.7616, -0.6514, -0.5047, -0.3215, -0.1107, 0.1107, 0.3215, 0.5047,
# 0.6514, 0.7616])
3、ReLU函数
-
公式: f ( x ) = { 0 f o r x < 0 x f o r x ⩾ 0 f\left( x \right) =\begin{cases} 0\,\,\, for\,\,x\,\,<\,\,0\\ x\,\, for\,\,x\,\,\geqslant \,\,0\\ \end{cases} f(x)={0forx<0xforx⩾0
-
导数: f ( x ) = { 0 f o r x < 0 1 f o r x ⩾ 0 f\left( x \right) =\begin{cases} 0\,\,\, for\,\,x\,\,<\,\,0\\ 1\,\, for\,\,x\,\,\geqslant \,\,0\\ \end{cases} f(x)={0forx<01forx⩾0
-
代码
import torch
from torch.nn import functional as F
a = torch.linspace(-1,1,10)
print(a)
print(torch.relu(a))
print(F.relu(a)) #两种调用函数形式
# tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111, 0.1111, 0.3333, 0.5556,
# 0.7778, 1.0000])
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1111, 0.3333, 0.5556, 0.7778,
# 1.0000])
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1111, 0.3333, 0.5556, 0.7778,
# 1.0000])