【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”

tensorflow.js实现了几种RNN的接口,包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用tensorflow.js训练RNN学习加法运算,即给出一个加法算式的字符串,算出数字结果,类似于自然语言处理。

1、生成训练、测试数据

// digits-每个字符位数,trainingSize-训练集大小
function generateData(digits, trainingSize) {
  // 所有可选字符集
  const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
  const arraySize = digitArray.length;

  // 输出
  const output = [];
  const maxLen = digits + 1 + digits;

  // 从digitArray挑选digits个数据拼为一个数字
  const f = () => {
    let str = '';
    while (str.length < digits) {
      const index = Math.floor(Math.random() * arraySize);
      str += digitArray[index];
    }
    return Number.parseInt(str);
  };

  // 生成trainingSize组数据
  while (output.length < trainingSize) {
    const a = f();
    const b = f();

    const q = `${a}+${b}`;
    // 补空格
    const query = q + ' '.repeat(maxLen - q.length);
    let ans = (a + b).toString();
    // 补空格
    ans += ' '.repeat(digits + 1 - ans.length);
    output.push([query, ans]);
  }
  return output;
}

digits代表输入数字的位数,比如567的位数是3。函数f从digitArray中随机挑选digits个数拼为一个输入。输入a、加号、输入b整体拼为一个query,a+b的真实结果拼为ans。为防止第一个数字为0改变数字位数,query和ans均后补空格,函数返回query、ans字符对。

2、数据分组并转为tensor

// 90%训练集,10%测试集
const split = Math.floor(trainingSize * 0.9);
this.trainData = data.slice(0, split);
this.testData = data.slice(split);

// 转为tensors,并分为训练组、测试组
[this.trainXs, this.trainYs] = convertDataToTensors(this.trainData, this.charTable, digits);
[this.testXs, this.testYs] = convertDataToTensors(this.testData, this.charTable, digits);

将generateData生成的数据分为训练组和验证组,并将字符串转为tensor。转换函数converDataToTensors如下。

function convertDataToTensors(data, charTable, digits) {
  const maxLen = digits + 1 + digits;
  // data中每一项datum = [query, ans]
  const questions = data.map(datum => datum[0]);
  const answers = data.map(datum => datum[1]);
  return [
    charTable.encodeBatch(questions, maxLen),
    charTable.encodeBatch(answers, digits + 1),
  ];
}

对query、ans编码,需要字符集类CharacterTable。

class CharacterTable {
  constructor(chars) {
    this.chars = chars;
    // 字符-位置index
    this.charIndices = {};
    // 位置index-字符
    this.indicesChar = {};
    this.size = this.chars.length;
    for (let i = 0; i < this.size; ++i) {
      const char = this.chars[i];
      this.charIndices[this.chars[i]] = i;
      this.indicesChar[i] = this.chars[i];
    }
  }

  // 输入questions、answers数组,输出转化的tensor
  encodeBatch(strings, maxLen) {
    const numExamples = strings.length;
    const buf =
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值