tensorflow.js基本使用 XOR(三)

示例 

import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

function getData(numSamples) {
  let points = [];

  function genGauss(cx, cy, label) {
    for (let i = 0; i < numSamples / 2; i++) {
      let x = normalRandom(cx);
      let y = normalRandom(cy);
      points.push({ x, y, label });
    }
  }

  genGauss(2, 2, 0);
  genGauss(-2, -2, 0);
  genGauss(-2, 2, 1);
  genGauss(2, -2, 1);
  return points;
}

function normalRandom(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1);

  let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  return mean + Math.sqrt(variance) * result;
}

$(async () => {
  $('#xf').on('submit',()=>{
    if(window.predict){
      window.predict({x:$('#xf #x').val(),y:$('#xf #y').val()});
    }else{
      alert('模型还没有训练完');
    }
    return false;
  });

  const data = getData(400);
  console.log(data);

  tfvis.render.scatterplot(
    { name: "xor" },
    {
      values: [
        data.filter(item => item.label === 1),
        data.filter(item => item.label === 0)
      ]
    }
  );

  const model = tf.sequential();

  //设置全连接层
  model.add(tf.layers.dense({
    units: 9,
    inputShape: [2],
    activation: 'relu'
  }));

  //设置输出层
  model.add(tf.layers.dense({
    units: 1,
    activation: 'sigmoid'
  }));

  //设置损失函数,优化器
  model.compile({
    loss: tf.losses.logLoss,
    optimizer: tf.train.adam(0.1)
  });

  const inputs = tf.tensor(data.map((item) => [item.x, item.y]));
  const labels = tf.tensor(data.map((item) => item.label));

  await model.fit(inputs, labels, {
    epochs: 10,
    callbacks: tfvis.show.fitCallbacks(
      { name: '训练过程' },
      ['loss']
    )
  });

  window.predict = async (form) => {
    const pred = await model.predict(tf.tensor([[form.x * 1, form.y * 1]]));
    alert(`预测结果${pred.dataSync()[0]}`);
  }

  //保存模型
  window.download = async () => {
    await model.save('downloads://my-model');
  }
});

html部分

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
</head>
<body>
  <div>xor数据集</div>
  <form id="xf">
    <label for="x">
      x<input type="text" name="x" id="x"/>
    </label><br/>
    <label for="y">
      y<input type="text" name="y" id="y"/>
    </label><br/>
    <input type="submit" value="提交">
  </form>
  <button onclick="download()">保存模型</button>
</body>
<script src="./t3.js"></script>
</html>

执行结果

 

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值