tfjs mnist 手写数字识别web版

原文链接: tfjs mnist 手写数字识别web版

上一篇: tensorflow 线性模型保存为pb格式,并且在tfjs中使用

下一篇: pytorch 环境搭建 安装使用

效果

64e258201b7e93981a382c7fe96ded288d9.jpg

1d7008a0b7925899b439d26c419207a78de.jpg

4404be1053b12267a401113b4376e71f6ab.jpg

模型下载

链接:https://pan.baidu.com/s/1D7ULNIgIJlZAxWmhcDnlqw
提取码:1hga

包含三个文件

6b46e8ff29ad04624598983cdba5ad53ed8.jpg

vue文件,挂载后加载模型,然后使用模型进行预测

使用自定义函数将图像画在canvas中

        let ctx = document.getElementById('num')
        tf.toPixels(tf.tensor(this.img, [28, 28, 1]), ctx)

需要安装tfjs和vconsole(可以不装,用于移动端调试)

输入张量包含两个,一个为图像矩阵1*28*28*1 一个是keepProb的值

加载后的model中可以看到需要的输入和能够得到的输出

96a3e95682ecb2ca763e30309859ca28f09.jpg

<template>
  <div id="app">
    <div class="main" @mouseup.prevent="isDraw=false" @mousedown.prevent="isDraw=true">
      <div class="grid">
        <div :class="img[i-1]==1?'cell_black':'cell_gray'" v-for="i in 28*28" @mouseover="isDraw && draw(i)"></div>
      </div>

      <canvas class="num" id="num"></canvas>
      <h3>预测结果:{{num}}</h3>

      <div class="btns">
        <button @click="reset">reset</button>
        <button @click="submit">submit</button>
      </div>
    </div>
  </div>
</template>

<script>
  import * as tf from '@tensorflow/tfjs';
  import {loadFrozenModel} from '@tensorflow/tfjs-converter';

  var VConsole = require('vconsole/dist/vconsole.min.js');
  new VConsole();
  let t = tf.tensor([1, 2, 3])
  console.log(t.dataSync(), typeof t.dataSync())
  console.log(t, typeof t)
  console.log(t[0], typeof t[0])
  let s = JSON.stringify(t)
  console.log(s, typeof s)
  export default {
    name: "draw",
    data() {
      return {
        isDraw: false,
        img: Array(28 * 28).fill(0),
        num: '?',
        model: '',
      }
    },
    methods: {
      draw(i) {
        this.$set(this.img, i - 1, 0 + !this.img[i - 1])
      },
      async submit() {
        let x = tf.tensor(this.img, [1, 28, 28, 1])
        x.print()
        console.log(this.img)
        let p = this.model.predict({
          "Placeholder": x,
          "Placeholder_2": tf.scalar(1.),
        })
        p.print()
        console.log(p)
        let ctx = document.getElementById('num')
        tf.toPixels(tf.tensor(this.img, [28, 28, 1]), ctx)
        this.num = p.dataSync()[0]
      },
      reset() {
        this.img = Array(28 * 28).fill(0)
        this.num = '?'
      }
    },
    async mounted() {
      const MODEL_DIR = './static/tfjs_mnist/';
      const MODEL_URL = 'tensorflowjs_model.pb';
      const WEIGHTS_URL = 'weights_manifest.json';
      this.model = await tf.loadFrozenModel(
        MODEL_DIR + MODEL_URL,
        MODEL_DIR + WEIGHTS_URL);
      console.log(this.model)
    }
  }

</script>

<style>
  .main {
    width: 100vw;
    height: 100vh;
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
  }

  .grid {
    display: grid;
    grid-template-columns: repeat(28, 1fr);
    width: 700px;
    height: 700px;
    border: 1px solid black;
  }

  .cell_gray {
    border: 1px solid black;
    background: rgb(233, 233, 233);
  }

  .cell_black {
    border: 1px solid black;
    background: black;

  }

  .num {
    width: 140px;
    height: 140px;
    border: 1px solid black;
    margin: 5px;
  }

  .btns {
    width: 150px;
    display: flex;
    justify-content: space-between;
    margin: 5px;
  }
</style>

Python 模型保存代码,需要安装tensorflowjs

数据集下载

链接:https://pan.baidu.com/s/1SpMscvnUcNc3J22zCX0skw
提取码:0jot

注意该模型中含有keepProb参数,也就是说输入张量有两个

将模型保存为pb格式,其中必须指定输出节点名称,然后再转化为tfjs使用的格式,在tfjs中给定输入张量即可完成预测

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['prediction'])
    with tf.gfile.FastGFile(SAVE_PATH, mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    tfjs.converters.tf_saved_model_conversion.convert_tf_frozen_model('./pb/graph.pb', 'prediction', './tfjsmodel')

完整代码

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.python.framework import graph_util
import tensorflowjs as tfjs
import os

TRAIN_PATH = "D:/data/Digit Recognizer/train.csv"
TEST_PATH = "D:/data/Digit Recognizer/test.csv"
learning_rate = .0001
TRAIN_STEP = 50000
BATCH_SIZE = 512
SHOW_STEP = 100

in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))
in_y = tf.placeholder(tf.float32, (None, 10))
keep_prob = tf.placeholder(tf.float32)
SAVE_DIR = './pb/'
SAVE_PATH = os.path.join(SAVE_DIR, 'graph.pb')


# 预处理返回的是28*28的图像,需要进行reshape
def preprocessing(img):
    return img


# 输入的input 为n*28*28*1
def get_net():
    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],

            # activation_fn=tf.nn.relu,
            activation_fn=tf.nn.relu6,
            # activation_fn=tf.nn.leaky_relu,
    ):
        net = slim.conv2d(in_x, 64, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 14, 14, 64)
        net = slim.conv2d(net, 32, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 7, 7, 32)
        net = slim.conv2d(net, 16, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 4, 4, 16)

        net = slim.flatten(net)
        print(net.shape)  # (?, 16)
        net = slim.fully_connected(net, 128)
        net = slim.dropout(net, keep_prob)
        print(net.shape)  # (?, 128)
        net = slim.fully_connected(net, 32)
        net = slim.dropout(net, keep_prob)
        print(net.shape)  # (?, 32)
        net = slim.fully_connected(net, 10)
        net = slim.dropout(net, keep_prob)
        print(net.shape)  # (?, 10)
        return net


net = get_net()
# loss_op = tf.reduce_mean((net - in_y) ** 2)
# loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=net, logits=in_y))
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=in_y, logits=net))
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)
prediction_result = tf.argmax(net, 1, name='prediction')
prediction_tr = tf.equal(tf.argmax(net, 1), tf.argmax(in_y, 1))
accuracy_tr = tf.reduce_mean(tf.cast(prediction_tr, tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())


# 根据data获取求最终结果
# image_data N*28*28*1
# ans N*1
def get_ans(image_data):
    ans = sess.run(net, {
        in_x: image_data,
        keep_prob: 1.
    })
    ans = np.argmax(ans, axis=1)
    return ans


# 根据data和对应的label训练模型,并保存模型
# data N*28*28*1
# label N*10
def train(data, label):
    id_list = range(len(data))
    for i in range(1, 1 + TRAIN_STEP):
        ids = np.random.choice(id_list, BATCH_SIZE)
        image_batch = data[ids]
        label_batch = label[ids]
        sess.run(train_op, {
            in_x: image_batch,
            in_y: label_batch,
            keep_prob: .5,
        })

        if not i % SHOW_STEP:
            ids = np.random.choice(id_list, BATCH_SIZE * 10)
            image_batch = data[ids]
            label_batch = label[ids]
            loss_val, accuracy_val = sess.run(
                [loss_op, accuracy_tr], {
                    in_x: image_batch,
                    in_y: label_batch,
                    keep_prob: 1.,
                }
            )
            print(i, loss_val, accuracy_val)


def main():
    data = np.loadtxt(TRAIN_PATH, dtype=np.str, delimiter=',')
    image_data = data[1:, 1:].reshape((-1, 28, 28, 1)).astype(np.uint8)

    image_data = np.stack(
        [
            preprocessing(img).reshape((28, 28, 1))
            for img in image_data
        ]
    ).astype(np.float32)
    image_data /= 255
    label_data = tf.keras.utils.to_categorical(data[1:, 0], 10).astype(np.float32)
    train(image_data, label_data)

    test_data = np.loadtxt(TEST_PATH, dtype=np.str, delimiter=',')
    image_data = test_data[1:, :].reshape((-1, 28, 28, 1)).astype(np.uint8)

    image_data = np.stack(
        [
            preprocessing(img).reshape((28, 28, 1))
            for img in image_data
        ]
    ).astype(np.float32)

    image_data /= 255

    # 直接传入全部测试数据,计算量很大,也占用内存
    # ans = get_ans(image_data)

    # 分批次计算,然后聚合结果
    ans = []
    for batch in np.split(image_data, 20):
        ans.append(get_ans(batch))
    ans = np.hstack(ans)

    print(ans.shape)

    # 将结果写入文件
    with open('ans.txt', mode='w+', encoding='utf8') as f:
        f.write('ImageId,Label\n')
        for i, j in enumerate(ans):
            f.write(f"{i+1},{j}\n")

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['prediction'])
    with tf.gfile.FastGFile(SAVE_PATH, mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    tfjs.converters.tf_saved_model_conversion.convert_tf_frozen_model('./pb/graph.pb', 'prediction', './tfjsmodel')


if __name__ == '__main__':
    main()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值