【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST

本文档介绍了如何使用tensorflow.js构建卷积神经网络(CNN)来识别MNIST手写数字数据集。内容包括定义MNIST数据类、构建CNN模型、设置训练参数、模型训练过程以及可视化过程中遇到的错误及其解决方案。最后提供了完整项目的GitHub链接和运行步骤。
摘要由CSDN通过智能技术生成

笔记(1)中利用tensorflow.js完成了机器学习中曲线拟合的任务,这篇笔记将实现一个经典的机器学习问题——CNN识别手写数字集MNIST。参考官方示例Training on Images: Recognizing Handwritten Digits with a Convolutional Neural Network,修改部分代码并用echarts改写vega。

1、定义mnist数据类

import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;//图片大小28*28
const NUM_CLASSES = 10;//类别数
const NUM_DATASET_ELEMENTS = 65000;//总样本数
const NUM_TRAIN_ELEMENTS = 55000;//训练样本数
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;//测试样本数

const MNIST_IMAGES_SPRITE_PATH = './src/mnist_images.png';//mnist图像
const MNIST_LABELS_PATH = './src/mnist_labels_uint8';//mnist图像对应的类别

export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;
        const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
            datasetBytesBuffer,
            i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize
          );
          ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    this.trainImages = this.datas
  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值