JAX 来构建一个基本的人工神经网络(ANN)进行分类任务

本文介绍如何使用JAX库在Python中构建一个简单的全连接神经网络进行图像分类,包括模型定义、参数初始化、损失函数、优化器和训练过程。同时展示了如何将训练好的模型应用到FlaskWeb应用中进行实时预测。
摘要由CSDN通过智能技术生成
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.experimental import optimizers
from jax.nn import relu, softmax

# 构建神经网络模型
def neural_network(params, x):
    for W, b in params:
        x = jnp.dot(x, W) + b
        x = relu(x)
    return softmax(x)

# 初始化参数
def init_params(rng, layer_sizes):
    keys = random.split(rng, len(layer_sizes))
    return [(random.normal(k, (m, n)), random.normal(k, (n,))) 
            for k, (m, n) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:]))]

# 定义损失函数
def cross_entropy_loss(params, batch):
    inputs, targets = batch
    preds = neural_network(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

# 初始化优化器
def init_optimizer(params):
    return optimizers.adam(init_params)

# 更新参数
@jit
def update(params, batch, opt_state):
    grads = grad(cross_entropy_loss)(params, batch)
    updates, opt_state = opt.update(grads, opt_state)
    return opt_params, opt_state

# 训练函数
def train(rng, params, data, num_epochs=10, batch_size=32):
    opt_init, opt_update, get_params = init_optimizer(params)
    opt_state = opt_init(params)
    
    num_batches = len(data) // batch_size
    
    for epoch in range(num_epochs):
        rng, subrng = random.split(rng)
        for batch_idx in range(num_batches):
            batch = get_batch(data, batch_idx, batch_size)
            params = update(params, batch, opt_state)
        
        train_loss = cross_entropy_loss(params, batch)
        print(f"Epoch {epoch+1}, Loss: {train_loss}")
    
    return get_params(opt_state)

# 评估函数
def evaluate(params, data):
    inputs, targets = data
    preds = neural_network(params, inputs)
    accuracy = jnp.mean(jnp.argmax(preds, axis=1) == jnp.argmax(targets, axis=1))
    return accuracy

# 示例数据集和参数
rng = random.PRNGKey(0)
input_size = 784
num_classes = 10
layer_sizes = [input_size, 128, num_classes]
params = init_params(rng, layer_sizes)
opt = init_optimizer(params)

# 使用数据集进行训练
trained_params = train(rng, params, data)

# 评估模型
accuracy = evaluate(trained_params, test_data)
print("Test Accuracy:", accuracy)

理解如何使用 JAX 或其他深度学习库构建人工智能(AI)系统需要一定的学习和实践。下面我给你一个简单的例子来说明如何使用 JAX 来构建一个基本的人工神经网络(ANN)进行分类任务。

首先,让我们假设你想解决一个简单的图像分类问题,例如手写数字识别。我们将使用一个基本的全连接神经网络来实现这个任务。

这只是一个简单的示例,用于说明如何使用 JAX 来构建神经网络进行图像分类任务。实际情况下,你可能需要更复杂的网络结构、更大规模的数据集以及更多的训练技巧来实现更好的性能。继续学习和实践将帮助你更好地理解如何构建 AI 系统。

要生成并存储模型文件,你可以使用 joblib 库,就像之前保存模型一样。以下是评估模型并保存模型的代码示例:

python
import joblib

# 评估模型
accuracy = evaluate(trained_params, test_data)
print("Test Accuracy:", accuracy)

# 将训练好的模型保存为文件
joblib.dump(trained_params, 'trained_model.pkl')


此代码评估了训练好的模型在测试数据集上的准确率,并将模型保存为名为 trained_model.pkl 的文件。在此之后,你可以将 trained_model.pkl 文件用于部署模型或在其他地方进行预测。

让我们假设你已经训练了一个模型来识别手写数字。现在,我将展示如何结合手写图片应用并输出识别结果。我们将使用 Python 的 Flask 框架来构建一个简单的 Web 应用,并在用户上传手写数字图片后,使用训练好的模型进行预测。

首先,确保你已经安装了 Flask:

bash

pip install flask


然后,你可以创建一个名为 app.py 的 Python 脚本,其中包含以下内容:

python
from flask import Flask, render_template, request
from PIL import Image
import numpy as np
import joblib

app = Flask(__name__)

# 加载训练好的模型
model = joblib.load('trained_model.pkl')

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    # 获取上传的图片文件
    file = request.files['file']
    
    # 将上传的图片转换为灰度图像并缩放为 28x28 像素
    img = Image.open(file).convert('L').resize((28, 28))
    
    # 将图像数据转换为 numpy 数组
    img_array = np.array(img) / 255.0  # 将像素值缩放到 [0, 1] 范围内
    
    # 将图像数据扁平化成一维数组
    img_flat = img_array.flatten()
    
    # 使用模型进行预测
    prediction = model.predict([img_flat])[0]
    
    return render_template('predict.html', prediction=prediction)

if __name__ == '__main__':
    app.run(debug=True)


上述代码创建了一个基本的 Flask 应用,包括两个路由:

- / 路由用于渲染主页,其中包含一个表单,允许用户上传手写数字图片。
- /predict 路由用于接收上传的图片并使用模型进行预测。

接下来,你需要创建两个 HTML 模板文件 index.html 和 predict.html,并放置在名为 templates 的文件夹中。index.html 用于渲染主页,而 predict.html 用于显示预测结果。

index.html 内容如下:

html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Handwritten Digit Recognition</title>
</head>
<body>
    <h1>Handwritten Digit Recognition</h1>
    <form action="/predict" method="post" enctype="multipart/form-data">
        <input type="file" name="file" accept="image/*">
        <button type="submit">Predict</button>
    </form>
</body>
</html>

现在,你可以运行应用:

bash

python app.py


然后在浏览器中访问 http://localhost:5000/,上传手写数字图片并查看预测结果。

JAX中,可以使用`jax.nn`模块来构建神经网络。下面是一个示例代码,用于构建一个具有2个隐藏层的神经网络,并检查其所有权重和偏置: ```python import jax import jax.numpy as jnp from jax import random, nn # 定义神经网络模型 def neural_network(params, x): for w, b in params: x = nn.relu(jnp.dot(x, w) + b) return x # 初始化随机参数 key = random.PRNGKey(0) input_shape = (10,) # 输入形状为(10,) hidden_units = [20, 30] # 隐藏层单元数为[20, 30] output_units = 1 # 输出层单元数为1 # 随机初始化权重和偏置 layer_sizes = [input_shape[0]] + hidden_units + [output_units] keys = random.split(key, len(layer_sizes)) params = [(random.normal(k, (m, n)), random.normal(k, (n,))) for k, m, n in zip(keys, layer_sizes[:-1], layer_sizes[1:])] # 构建神经网络 def model(x): return neural_network(params, x) # 检查神经网络的所有权重和偏置 weights = [w for w, _ in params] biases = [b for _, b in params] print("权重:") for i, w in enumerate(weights): print(f"层 {i+1} 的权重:\n{w}") print("\n偏置:") for i, b in enumerate(biases): print(f"层 {i+1} 的偏置:\n{b}") ``` 这段代码首先定义了一个`neural_network`函数,用于计算神经网络的前向传播。然后,使用`random.normal`函数初始化了神经网络的权重和偏置。最后,通过遍历参数列表,打印出了每一层的权重和偏置。 希望对你有帮助!如果有任何问题,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值