1.tanh激活函数公式
tanh ( x ) = e x − e − x e x + e − x \tanh \left( x \right) =\frac{e^x-e^{-x}}{e^x+e^{-x}} \notag tanh(x)=ex+e−xex−e−x
2.基于JAX实现tanh激活函数
import jax.numpy as jnp
from jax import random
def tanh(x):
"""tanh function"""
return (jnp.exp(x)-jnp.exp(-x))/(jnp.exp(x)+jnp.exp(-x))
# 设置伪随机数种子
rng=random.PRNGKey(0)
# 标准正态采样得到输入向量
x=random.normal(rng,shape=(4,1))
# 调用内置tanh函数实现
print(jnp.tanh(x))
# 调用自定义tanh函数实现
print(tanh(x))
3.softmax函数公式
s i = e x i ∑ i = 0 j e x i s_i=\frac{e^{x_i}}{\sum\limits _{i=0}^{j}e^{x_i} } \notag si=i=0∑jexiexi
其中 x i x_i xi表示第 i i i个神经元实值输出
4.基于JAX实现softmax函数
import jax.numpy as jnp
import jax.nn as nn
def softmax(x,axis=-1):
"""softmax function"""
unnormalized=jnp.exp(x)
return unnormalized/jnp.sum(unnormalized)
# 定义数组
arr=jnp.arange(-2,4)
# 调用自定义softmax函数
print(softmax(arr))
# 调用jax自带softmax函数
print(nn.softmax(arr))
5.交叉熵函数公式
H ( p , q ) = − ∑ i = 1 n p ( x i ) log ( q ( x i ) ) H\left( p,q \right) =-\sum_{i=1}^n{p\left( x_i \right) \log \left( q\left( x_i \right) \right)} \notag H(p,q)=−i=1∑np(xi)log(q(xi))
其中 p ( x ) p(x) p(x)表示真实概率分布, q ( x ) q(x) q(x)表示预测概率分布
6.基于JAX实现交叉熵函数
import jax.numpy as jnp
def cross_entropy(y_true,y_pred,eps=1e-7):
"""cross entropy function
:param y_true:真实标签
:param y_pred:神经网络预测标签
:param eps:默认极小正数,保证对数真数不为0,增强log函数数值稳定性
:return:交叉熵,保留到小数点后4位
"""
y_true=jnp.array(y_true)
y_pred=jnp.array(y_pred)
res=-jnp.sum(y_true*jnp.log(y_pred+eps),axis=-1)
return jnp.round(res,4)
# 预测概率分布
y_pred=[0.1,0.05,0.6,0.0,0.05,0.1,0.0,0.1,0.0,0.0]
# 真实概率分布
y_true=[0,0,1,0,0,0,0,0,0,0]
# 交叉熵为0.5108
print(cross_entropy(y_true,y_pred))