tensorfow(六)基于tensorflow的手写数字识别

基于tensorflow的手写数字识别

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#载入数据集

batch_size = 100#每个批次的大小
n_batch = mnist.train.num_examples//batch_size#计算一共有多少个批次

x = tf.placeholder(tf.float32,[None,784])#定义两个变量
y = tf.placeholder(tf.float32,[None,10])

#构建一个简单的神经网络
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

#结果存放在布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc))

运行结果:

Iter0,Testing Accuracy0.8318
Iter1,Testing Accuracy0.8697
Iter2,Testing Accuracy0.8817
Iter3,Testing Accuracy0.8878
Iter4,Testing Accuracy0.893
Iter5,Testing Accuracy0.897
Iter6,Testing Accuracy0.8997
Iter7,Testing Accuracy0.9017
Iter8,Testing Accuracy0.9043
Iter9,Testing Accuracy0.9046
Iter10,Testing Accuracy0.906
Iter11,Testing Accuracy0.9078
Iter12,Testing Accuracy0.9083
Iter13,Testing Accuracy0.9091
Iter14,Testing Accuracy0.9102
Iter15,Testing Accuracy0.9106
Iter16,Testing Accuracy0.9118
Iter17,Testing Accuracy0.9122
Iter18,Testing Accuracy0.9129
Iter19,Testing Accuracy0.9135
Iter20,Testing Accuracy0.9133

使用最基本的神经网络准确率可以达到0.9133,该神经网络中有诸多可以修改的地方以提高最终的准确率

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
1. 引入依赖 首先需要引入 tensorflow.js 的依赖,可以通过以下方式引入: ```html <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.0.0/dist/tf.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mnist@2.0.1"></script> ``` 2. 创建画布 我们需要在页面中创建一个画布,用户可以在上面手写数字。代码如下: ```html <canvas id="canvas" width="280" height="280"></canvas> ``` 3. 加载模型 接下来,我们需要加载训练好的模型。MNIST 模型是一个用于手写数字识别的深度学习模型。我们可以通过以下代码加载模型: ```javascript const model = await tf.loadLayersModel("https://storage.googleapis.com/tfjs-models/tfjs/mnist_v3/model.json"); ``` 4. 预处理数据 在使用模型进行预测之前,需要将用户手写的数字转换为模型所需的格式。我们可以将画布上的像素数据转换为一个 28x28 的张量,并将其归一化到 0 到 1 的范围内。 ```javascript const canvas = document.getElementById("canvas"); const ctx = canvas.getContext("2d"); const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height); const data = imgData.data; const input = []; for (let i = 0; i < data.length; i += 4) { input.push(data[i + 2] / 255); } const tensor = tf.tensor(input, [1, 28, 28, 1]); ``` 5. 进行预测 最后,我们可以将预处理后的数据输入到模型中进行预测。 ```javascript const output = model.predict(tensor); const predictions = output.dataSync(); console.log(predictions); ``` 6. 完整代码 ```html <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <title>Handwritten Digit Recognition</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.0.0/dist/tf.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mnist@2.0.1"></script> </head> <body> <h1>Handwritten Digit Recognition</h1> <canvas id="canvas" width="280" height="280"></canvas> <button onclick="predictDigit()">Predict Digit</button> <script> async function predictDigit() { const model = await tf.loadLayersModel( "https://storage.googleapis.com/tfjs-models/tfjs/mnist_v3/model.json" ); const canvas = document.getElementById("canvas"); const ctx = canvas.getContext("2d"); const imgData = ctx.getImageData(0, 0, canvas.width, canvas.height); const data = imgData.data; const input = []; for (let i = 0; i < data.length; i += 4) { input.push(data[i + 2] / 255); } const tensor = tf.tensor(input, [1, 28, 28, 1]); const output = model.predict(tensor); const predictions = output.dataSync(); console.log(predictions); } </script> </body> </html> ``` 以上就是基于 tensorflow.js 的在线手写数字识别的实现方法。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值