import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax import struct
import optax
import tensorflow_datasets as tfds
import numpy as np
import tensorflow as tf
# 1. Data preprocessing
def preprocess(image, label):
    """Resize to 32x32, scale to [0, 1], then normalize to [-1, 1]."""
    resized = tf.image.resize(image, [32, 32])  # CIFAR-10 is already 32x32; resize is a no-op safeguard
    scaled = tf.cast(resized, tf.float32) / 255.0
    normalized = (scaled - 0.5) / 0.5
    return normalized, label
# 2. Dataset loading
def get_datasets():
    """Load the CIFAR-10 train/test splits and build batched input pipelines.

    Returns:
        (train_ds, test_ds): tf.data pipelines yielding (image, label) batches
        of 128, preprocessed to [-1, 1]; the train pipeline is shuffled.
    """
    train_raw, test_raw = tfds.load(
        "cifar10",
        split=["train", "test"],
        as_supervised=True,
        shuffle_files=True,
        data_dir="./data",
    )
    train_pipeline = (
        train_raw.map(preprocess)
        .shuffle(10000)
        .batch(128)
        .prefetch(tf.data.AUTOTUNE)
    )
    test_pipeline = test_raw.map(preprocess).batch(128).prefetch(tf.data.AUTOTUNE)
    return train_pipeline, test_pipeline
# 3. Model definition — simplified: no BatchNorm, so there is no batch_stats
# collection to thread through train/eval.
class ImprovedCNN(nn.Module):
    """Three conv/pool stages, global average pooling, then a small MLP head."""

    @nn.compact
    def __call__(self, x, training: bool = False):
        # Conv stages 64 -> 128 -> 256, each followed by ReLU and 2x2 max-pool.
        # (Flax numbers submodules by call order, so the loop yields the same
        # Conv_0/Conv_1/Conv_2 parameter names as three inline calls would.)
        for num_features in (64, 128, 256):
            x = nn.Conv(features=num_features, kernel_size=(3, 3), padding='SAME')(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Global average pooling over the spatial dimensions.
        x = jnp.mean(x, axis=(1, 2))
        # Fully connected head; dropout is active only while training.
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.3, deterministic=not training)(x)
        # 10-way class logits.
        x = nn.Dense(features=10)(x)
        return x
# Custom train state carrying the dropout PRNG key alongside params/optimizer,
# so each jitted train step can split a fresh key and store the successor.
@struct.dataclass
class CustomTrainState(train_state.TrainState):
    # PRNG key consumed (split) once per train step for the dropout mask.
    dropout_rng: jax.Array
# 4. Loss & accuracy metrics
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy of `logits` against integer `labels`.

    Args:
        logits: float array of shape (batch, 10) — unnormalized class scores.
        labels: int array of shape (batch,) — class indices in [0, 10).

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    targets = jax.nn.one_hot(labels, num_classes=10)
    # -sum(targets * log_softmax(logits)) is exactly what
    # optax.softmax_cross_entropy computes; doing it with jax.nn avoids
    # pulling optax into a two-line metric.
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))
def compute_accuracy(logits, labels):
    """Fraction of rows whose argmax class equals the integer label."""
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == labels)
# 5. Train-state construction
def create_train_state(rng, learning_rate=0.01):
    """Build a CustomTrainState: init params + AdamW with warmup/cosine LR.

    Args:
        rng: PRNG key; consumed for parameter init and the dropout key.
        learning_rate: peak LR reached after warmup (default 0.01).
    """
    model = ImprovedCNN()
    # Dummy batch traced only to determine parameter shapes.
    dummy_input = jnp.ones((1, 32, 32, 3), dtype=jnp.float32)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    params = model.init(params_rng, dummy_input, training=False)['params']
    # LR schedule: linear warmup for 500 steps, then cosine decay to 0
    # over the remaining steps.
    warmup_steps = 500
    total_steps = 5000
    warmup = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=warmup_steps,
    )
    decay = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=total_steps - warmup_steps,
    )
    schedule = optax.join_schedules([warmup, decay], boundaries=[warmup_steps])
    tx = optax.adamw(learning_rate=schedule, weight_decay=1e-4)
    return CustomTrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
        dropout_rng=dropout_rng,
    )
# 6. Train / eval steps (JIT-compiled)
@jax.jit
def train_step(state, batch):
    """Run one optimizer step on a (images, labels) batch.

    Returns:
        (new_state, loss, accuracy) — new_state carries updated params,
        optimizer state, and the advanced dropout PRNG key.
    """
    images, labels = batch
    # Consume the state's dropout key now; carry the fresh successor forward.
    dropout_rng, next_dropout_rng = jax.random.split(state.dropout_rng)

    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params},
            images,
            training=True,
            rngs={'dropout': dropout_rng},
        )
        return cross_entropy_loss(logits, labels), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads).replace(dropout_rng=next_dropout_rng)
    accuracy = compute_accuracy(logits, labels)
    return new_state, loss, accuracy
@jax.jit
def eval_step(state, batch):
    """Compute (loss, accuracy) on one batch with dropout disabled."""
    images, labels = batch
    # training=False keeps Dropout deterministic, so no RNG is needed.
    logits = state.apply_fn({'params': state.params}, images, training=False)
    return cross_entropy_loss(logits, labels), compute_accuracy(logits, labels)
# 7. Main training loop
def main():
    """Train the CNN on CIFAR-10 for 10 epochs, printing metrics each epoch."""
    rng = jax.random.PRNGKey(42)
    rng, init_rng = jax.random.split(rng)
    train_ds, test_ds = get_datasets()
    state = create_train_state(init_rng)
    num_epochs = 10
    for epoch in range(num_epochs):
        # --- training pass ---
        epoch_losses, epoch_accs = [], []
        for batch in tfds.as_numpy(train_ds):
            state, loss, acc = train_step(state, batch)
            epoch_losses.append(loss)
            epoch_accs.append(acc)
        avg_loss = np.mean(epoch_losses)
        avg_acc = np.mean(epoch_accs)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f}")
        # --- evaluation pass ---
        eval_losses, eval_accs = [], []
        for batch in tfds.as_numpy(test_ds):
            loss, acc = eval_step(state, batch)
            eval_losses.append(loss)
            eval_accs.append(acc)
        test_acc = np.mean(eval_accs)
        print(f"Test Accuracy: {test_acc:.4f}")
    print("Training completed!")
# 8. Script entry point
if __name__ == "__main__":
    main()
# NOTE: trailing non-code text left over from the original paste (translated):
# "Raise the first-epoch training accuracy above 50%. Give me the optimized
# code." / "Latest release" — kept as a comment so the module parses.