optimizers.adam()
jax.experimental.optimizers.adam()返回三个参数,包含初始优化器状态、更新优化器状态、获取当前参数的函数.
语法:
opt_init, opt_update, get_params = optimizers.adam(lr)
opt_init:初始化优化器状态的函数. 它接受初始参数并返回一个优化器状态.
opt_update:更新优化器状态的函数. 它接受一个步数、梯度和当前的优化器状态,并返回更新后的优化器状态.
get_params:从优化器状态中提取当前参数的函数.
例如:
from jax.experimental import optimizers
import jax.numpy as jnp
import jax
# 假设我们有一个简单的线性模型和损失函数
def model(params, x):
return jnp.dot(x, params)
def loss_fn(params, x, y):
preds = model(params, x)
return jnp.mean((preds - y) ** 2)
# 初始化参数
params = jnp.array([0.0, 0.0]) # 初始权重
# 设置学习率
lr = 0.001
# 使用 Adam 优化器
opt_init, opt_update, get_params = optimizers.adam(lr) #使用指定的学习率初始化Adam优化器
# 初始化优化器状态
opt_state = opt_init(params)
# 定义更新步骤
@jax.jit #step函数使用jax.jit加速
def step(step, opt_state, x, y):
params = get_params(opt_state) #计算当前状态的参数
grads = jax.grad(loss_fn)(params, x, y) #计算当前参数的梯度
opt_state = opt_update(step, grads, opt_state) #更新优化器状态
return opt_state
# 示例数据
x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # 输入数据
y = jnp.array([1.0, 2.0]) # 真实值
# 训练循环
num_steps = 100
for step_num in range(num_steps):
opt_state = step(step_num, opt_state, x, y)
# 获取训练后的参数
trained_params = get_params(opt_state) #最后,从优化器状态中提取训练后的参数'trained_params'
print(trained_params)