JAX计算SeLU函数

1.SeLU(scaled exponential linear units)激活函数计算公式

selu ( x ) = λ { x  if  x > 0 α e x − α  if  x ⩽ 0. \text{selu}(x)= \lambda \begin{cases} x& \text{ if } x>0 \\ \alpha e^x-\alpha & \text{ if } x\leqslant 0. \end{cases} selu(x)=λ{xαexα if x>0 if x0.

其中 λ = 1.0507009873554804934193349852946 \lambda=1.0507009873554804934193349852946 λ=1.0507009873554804934193349852946 α = 1.6732632423543772848170429916717. \alpha=1.6732632423543772848170429916717. α=1.6732632423543772848170429916717.

2.JAX代码实现

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time        : 2022/7/20 13:51
@Author      : Albert Darren
@Contact     : 2563491540@qq.com
@File        : Program1.1.py
@Version     : Version 1.0.0
@Description : TODO 利用jax计算selu函数,详见P12
@Created By  : PyCharm
"""
import jax.numpy as jnp  # 导入numpy计算包
from jax import random  # 导入random随机数包


def selu(x, alpha=1.6732632423543772848170429916717, lmbda=1.0507009873554804934193349852946):
    """
    实现selu激活函数
    :param x: 输入张量
    :param alpha: 预定义参数alpha
    :param lmbda: 预定义参数lambda,此处变量名故意拼写错误,避免与关键字lambda命名冲突
    :return: selu函数值
    """
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))


# 产生一个固定数字17作为key
key = random.PRNGKey(17)
# 随机生成一个大小为[1,5]的矩阵
x = random.normal(key, (5,))
print(selu(x))
# [-1.2497659   0.4546819   1.5760192  -0.81573856  0.27510932]

3.参考文献

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值