小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别

一、题目要求

1.三个js文件,分别完成:网络训练以及模型保存、模型加载及准确率测试、在线手写数字识别;

2.模型测试准确率要高于99.3%(尽量);

3.在线手写数字识别需要能够通过鼠标在画布中写入0~9数字,并进行实时识别,按空格键清除。测试需具有一定的准确性。

二、实验原理

   利用卷积神经网络提高数字识别结果的精度。

   假设图像的尺寸是28*28,那么如果我们在下一层有1000个单位,我们就需要学习28*28*1000个单位的权重。像素可能是相关的,因此构建了一个k*k核作为权重学习的过滤器。

   池化没有需要学习的变量。它的作用是对图像进行细分采样,使下一层可以查看更大的空间区域。进一步缩小网络范围,减少需要学习的参数量。

    如何进一步提高准确性?添加noise或dropout;添加更多层;使用更多的epochs和更大的batch size;在模型中添加卷积层,使用卷积神经网络强化准确率。

    保存并加载 tf.Model的方法:tf.Modeltf.Sequential同时提供了函数 model.save 允许您保存一个模型的拓扑结构和权重。IndexedDB (仅限浏览器):await model.save('indexeddb://my-model');这样会将模型保存到浏览器的IndexedDB存储中。与本地存储一样,它在刷新后仍然存在,同时它往往也对存储的对象的大小有较大的限制。(参考链接:https://blog.csdn.net/Aria_Miazzy/article/details/103793323

三、设计思路

1. 准备工作

下载MNIST数据集http://yann.lecun.com/exdb/mnist/

数据读取需要下载并保存为mnist.js文件:

https://github.com/CodingTrain/Toy-Neural-Network-JS/blob/master/examples/mnist/mnist.js

添加加载数据集的代码:

loadMNIST(function (data) {
        mnist = data;
        console.log(mnist);
    })

三个页面三个js文件分别进行:

begin.html对应train.js完成网络模型的训练以及模型保存;

Load.html对应load.js完成模型的加载以及准确率测试;

Recognition.html对应recognition.js完成手写体在线实时识别

2. 添加网络

根据上图添加神经网络:(train.js)

(1)添加卷积层,大小为28*28,其中卷积核大小为5,使用的激活函数为relu;(2)添加池化层,尺寸为2*2;

(2)添加卷积层,卷积核个数为5,激活函数为relu;

(3)添加池化层;

(4)为了提高准确路,在此处添加dropout,并且rate=0.5;

(5)降维后添加全连接层,激活函数为relu;

(6)使用adam()优化器并设置rate=0.002,损失函数为softmaxCrossEntrop;至此完成了网络的配置。

// 初始化模型
const model = tf.sequential();
// Convolutional layer 二维卷积层
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],   // 1:颜色黑白
  kernelSize: 5,   // 卷积核大小为5
  filters: 16,     // 卷积核数量为16
  strides: 1,    // 步长为1
  activation: 'relu',    // 激活函数为relu
  kernelInitializer: 'varianceScaling'   // 初始化卷积核
}));
// 经过这层变化[28,28,1]-->[14,14,16]
// Pooling layer 二维池化层
model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],   // 尺寸
        strides: [2, 2]    // 步长
}));
// Convolutional layer 二维卷积层
model.add(tf.layers.conv2d({
  kernelSize: 5,   // 卷积核
  filters: 32,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'varianceScaling'
}));
// Pooling layer 池化层
model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2]
}));
// 添加dropout  rate = 0.5随机去掉一半
model.add(tf.layers.dropout({
    rate: 0.5
}));
// Flatten layer 降维
model.add(tf.layers.flatten());
// Dense layer 
model.add(tf.layers.dense({//全连接层
    units: 128,
    activation: 'relu'
}));
model.add(tf.layers.dense({
        units: 10,   // 对应0-10数字
    }));
const OPT = tf.train.adam(0.002)  // 优化器
    const config = {
    optimizer: OPT,
    loss: tf.losses.softmaxCrossEntropy, // 损失函数
}
model.compile(config);  //模型设置好配置

3. 加载数据

由于训练集数量比较大,这里选取了前60000个数据进行训练(train.js)

console.log("载入数据")
inputs = tf.tensor2d(mnist.train_images.slice(0, 60000));
outputs_org = tf.tensor1d(mnist.train_labels.slice(0, 60000));// 标签Y
outputs = tf.oneHot((outputs_org), 10);//全部对应到0-9  [0,0,0,0,0,0,0,1]
        
console.log("重组数据")  // 归一化除以255 变成0-1
inputs = tf.div(inputs, tf.scalar(255.0));
inputs = inputs.reshape([60000, 28, 28, 1]);  // 格式化28* 28* 1

4. 训练模型

这里使用15个epoh迭代,并且实时输出每一轮结果的loss.(train.js)

async function train() {
     for (let i = 1; i < 15; i++) {
         const h = await model.fit(inputs, outputs, {
                  atchSize: 200,
                  epochs: 1
                );
         console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
          }
         const saveResults = await model.save('indexeddb://my-model-6');
         console.log("模型已经保存");
         select('#modelStatus').html('模型已经训练完成并保存');
  }

5. 其中需要对模型进行保存和重加载

// 保存训练模型到浏览器数据库my-model-5
 const saveResults = await model.save('indexeddb://my-model');
// 加载已经保存的my-model模型,不需要重新训练
const model = await tf.loadLayersModel('indexeddb://my-model'); 

6. 测试训练准确率

首先加载测试数据,这里选择前10000个,之后进行训练(load.js)

console.log("加载测试数据。。")
        inputs_test = tf.tensor2d(mnist.test_images.slice(0, 10000));
        inputs_test = tf.div(inputs_test,tf.scalar(255.0));
        inputs_test = inputs_test.reshape([10000, 28, 28, 1]);
        outputs_test = tf.tensor1d(mnist.test_labels.slice(0, 10000));
        print(outputs_test.shape);
        console.log("测试数据加载完成")

async function test() {
       const model = await tf.loadLayersModel('indexeddb://my-model');
       console.log('加载已经保存的模型');
       output_tem = model.predict(inputs_test);
       label = tf.argMax(output_tem, 1);
       // 打印测试准确率
       tf.div(tf.sum(outputs_test.equal(label)), mnist.test_labels.length).print();
       result = tf.div(tf.sum(outputs_test.equal(label)), mnist.test_labels.length);
       select('#modelStatus').html('模型已经加载完成:' + result);
 }

7. 手写体识别可视化

实时鼠标在区域画数字,会进行预测,点击空格键删除。(recognition.js)

(参考链接:https://github.com/CodingTrain/Toy-Neural-Network-JS/blob/master/examples/mnist

let img = user_digit.get();
            if(!user_has_drawing) {
                return img;
            }
            let inputs = [];
            img.resize(28, 28);
            img.loadPixels();
            for (let i = 0; i < 784; i++) {
                inputs[i] = img.pixels[i * 4];
            }
            inputs = tf.tensor2d([inputs]);
            inputs = inputs.reshape([1,28,28,1]);
            let prediction = model.predict(inputs);
            let guess = tf.argMax(prediction,1);
            user_guess_ele.html(guess.dataSync());
            return img;

image(user_digit, 0, 0);
    // 鼠标控制画线,预测数字
    if (mouseIsPressed) {
        user_has_drawing = true;
        user_digit.stroke(255);
        user_digit.strokeWeight(16);
        user_digit.line(mouseX, mouseY, pmouseX, pmouseY);
 }

四、实验结果

1. 网络训练以及模型保存

运行页面加载模型并开始训练,显示每个epoch的loss值,迭代完成后模型保存,页面也显示‘模型已经训练完成并保存’:

2. 模型加载及准确率测试

模型保存完成之后,点击‘测试准确率’按钮,跳转到模型测试页面,加载测试数据并显示准确率。可见当前的准确率为99.39

3. 在线手写数字识别

  数据测试完成之后,点击‘开始手写识别’按钮,跳转到手写识别页面,可以随机用鼠标在电脑上画0-9的数值测试结果,猜测的数字会显示在下面,点击空格键重画。首先会显示“正在加载模型”。当模型加载好后会出现“模型已经加载完成”,之后可以进行手写识别,如下图:

五、总结提升

(1)使用Tensorflow.js构建深度模型。使用卷积神经网络提高准确率

(2)把数组数据转换成张量,把标签转换成一种热类型。转换绘图成28*28图像(img。调整大小(28、28))并将其平铺以供测试。

3三个js文件,分别完成:网络训练以及模型保存、模型加载及准确率测试、在线手写数字识别

(4)异步保存和加载模型(异步函数和等待)。

(5)如何进一步提高准确性?添加noise或dropout;添加更多层;使用更多的epochs和更大的batch size;在模型中添加卷积层,使用卷积神经网络强化准确率。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小白Rachel

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值