tensorflow.js基本使用 卷积神经网络(六)

识别0到9

import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs, img2x } from './utils.js';

$(async () => {
  // 预测
  document.querySelector('.pre').onclick = () => {
    if (window.predict) {
      window.predict();
    } else {
      alert('模型未训练完');
    }
  }

  // 写字绘画板↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
  const canvas = document.getElementsByClassName('cvs')[0];

  canvas.addEventListener('mousemove', (e) => {
    if (e.buttons === 1) {
      const ctx = canvas.getContext('2d');
      ctx.fillStyle = 'rgb(255,255,255)';
      ctx.fillRect(e.offsetX, e.offsetY, 10, 10);
    }
  });

  function clear() {
    const ctx = canvas.getContext('2d');
    ctx.fillStyle = 'rgb(0,0,0)';
    ctx.fillRect(0, 0, 300, 300);
  }

  document.querySelector('.cls').onclick = clear;
  clear();
  // 写字绘画板↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑

  const { inputs, labels, testInputs, testLabels } = await getInputs();

  // 设置训练集数据
  const { xs, ys } = tf.tidy(() => {
    const xs = tf.concat(inputs.map((item) => img2x(item)));
    const ys = tf.tensor(labels);
    return { xs, ys };
  });

  //设置测试集数据
  const { tx, ty } = tf.tidy(() => {
    const tx = tf.concat(testInputs.map(item => img2x(item)));
    const ty = tf.tensor(testLabels);
    return { tx, ty };
  });

  console.log(xs, ys);
  console.log(tx, ty);

  const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 255 } });

  inputs.forEach(imgEl => {
    imgEl.width = 28;
    imgEl.height = 28;
    imgEl.style = 'margin:4px;';
    surface.drawArea.appendChild(imgEl);
  });

  const model = tf.sequential();

  //卷积层
  model.add(tf.layers.conv2d({
    inputShape: [224, 224, 3],
    kernelSize: 10, //卷积盒个数
    filters: 10, //特征数量
    strides: 1, //卷积盒扫描时移动步长
    activation: 'relu', //激活函数,用于去除不常见特征
    kernelInitializer: 'varianceScaling', //初始化方法
  }));

  //最大池化层
  model.add(tf.layers.maxPooling2d({
    poolSize: [3, 3], //池化尺寸(可调)
    strides: [3, 3] //移动步数(可调)
  }));

  model.add(tf.layers.conv2d({
    kernelSize: 10,//卷积盒个数
    filters: 10,//特征数
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));

  //最大池化层
  model.add(tf.layers.maxPooling2d({
    poolSize: [3, 3], //池化尺寸(可调)
    strides: [3, 3] //移动步数(可调)
  }));

  //高维度数据,变到低维度
  model.add(tf.layers.flatten());

  //全连接层分类
  model.add(tf.layers.dense({
    units: 10, //0~9 10种分类
    activation: 'softmax', //多分类
    kernelInitializer: 'varianceScaling' //优化
  }));

  //设置损失函数,优化函数
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics: 'accuracy'
  });

  await model.fit(xs, ys, {
    validationData: [tx, ty],//设置验证集
    epochs: 20,
    callbacks: tfvis.show.fitCallbacks(
      { name: '训练效果' },
      //损失 验证集损失 精确度 验证集精确度
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { callbacks: ['onEpochEnd'] }
    )
  });

  window.predict = () => {
    const input = tf.tidy(() => {
      return tf.image.resizeBilinear(
        tf.browser.fromPixels(canvas),
        [224, 224],
        true,
      )
        .toFloat()//归一化
        .div(255)
        .reshape([1, 224, 224, 3])//与模型数据结构保持一致
    });

    const pred = model.predict(input).argMax(1);
    alert(`预测结果:${pred.dataSync()[0]}`);
  }

  //保存模型
  window.download = async () => {
    await model.save('downloads://my-model');
  }
});

html部分

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
</head>
<body>
  <div>卷积神经网络</div>
  <canvas class="cvs" width="300" height="300" style="border:2px solid black"></canvas>
  <div>
    <button class="cls">清除</button>
    <button class="pre">预测</button>
    <button onclick="download()">保存模型</button>
  </div>
  
</body>
<script src="./t5.js"></script>
</html>

utils.js

import * as tf from '@tensorflow/tfjs';

//载入测试图片的方法↓↓↓↓↓↓↓↓↓↓↓
const IMAGE_SIZE = 224;

const loadImg = (src) => {
  return new Promise(resolve => {
    const img = new Image();
    img.crossOrigin = "anonymous";
    img.src = src;
    img.width = IMAGE_SIZE;
    img.height = IMAGE_SIZE;
    img.onload = () => resolve(img);
  });
};

const pathArr = ['http://127.0.0.1:8080/nums', 'http://127.0.0.1:8080/num-test'];

export const getInputs = async () => {
  let res = [];
  for (let j in pathArr) {
    const loadImgs = [];
    const labels = [];
    for (let i = 0; i < 10; i += 1) {
      ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'].forEach(label => {
        const src = `${pathArr[j]}/${label}-${i}.png`;
        const img = loadImg(src);
        loadImgs.push(img);
        labels.push([
          label === '0' ? 1 : 0,
          label === '1' ? 1 : 0,
          label === '2' ? 1 : 0,
          label === '3' ? 1 : 0,
          label === '4' ? 1 : 0,
          label === '5' ? 1 : 0,
          label === '6' ? 1 : 0,
          label === '7' ? 1 : 0,
          label === '8' ? 1 : 0,
          label === '9' ? 1 : 0,
        ]);
      });
    }
    const inputs = await Promise.all(loadImgs);
    res.push({ inputs, labels });
  }
  return {
    inputs: res[0].inputs,
    labels: res[0].labels,
    testInputs: res[1].inputs,
    testLabels: res[1].labels
  };
}
//载入测试图片的方法↑↑↑↑↑↑↑↑↑↑↑

//图片格式转换↓↓↓↓↓↓↓↓↓↓↓
export function img2x(imgEl) {
  return tf.tidy(() => {
    const input = tf.browser.fromPixels(imgEl)
      .toFloat()
      .sub(255 / 2)
      .div(255 / 2)
      .reshape([1, 224, 224, 3]);
    return input;
  });
}
//图片格式转换↑↑↑↑↑↑↑↑↑↑↑

执行结果

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值