原文链接: tfjs mnist 手写数字识别web版
上一篇: tensorflow 线性模型保存为pb格式,并且在tfjs中使用
下一篇: pytorch 环境搭建 安装使用
效果
模型下载
链接:https://pan.baidu.com/s/1D7ULNIgIJlZAxWmhcDnlqw
提取码:1hga
包含三个文件
vue文件,挂载后加载模型,然后使用模型进行预测
使用 tf.toPixels 将图像张量画到 canvas 中(用于预览当前输入):
// `num` is the <canvas> element itself (not a 2D context); tf.toPixels draws into it.
const canvas = document.getElementById('num')
const preview = tf.tensor(this.img, [28, 28, 1])
// toPixels is asynchronous; release the tensor's memory once drawing completes.
tf.toPixels(preview, canvas).then(() => preview.dispose())
需要安装 @tensorflow/tfjs、@tensorflow/tfjs-converter 和 vconsole(vconsole 仅用于移动端调试,可以不装)
模型有两个输入张量:一个是形状为 1*28*28*1 的图像张量(输入节点名 Placeholder),另一个是 keepProb 的值(输入节点名 Placeholder_2,预测时传 1)
加载后的model中可以看到需要的输入和能够得到的输出
<template>
  <div id="app">
    <!-- Drawing is active only while the mouse button is held inside .main;
         leaving the area also stops drawing, so a release outside it is not missed. -->
    <div class="main"
         @mouseup.prevent="isDraw=false"
         @mousedown.prevent="isDraw=true"
         @mouseleave="isDraw=false">
      <div class="grid">
        <!-- One cell per pixel of the 28x28 input; `i` is 1-based, hence img[i-1]. -->
        <div v-for="i in 28*28"
             :key="i"
             :class="img[i-1]==1?'cell_black':'cell_gray'"
             @mouseover="isDraw && draw(i)"></div>
      </div>
      <!-- Preview of the submitted drawing, rendered by tf.toPixels. -->
      <canvas class="num" id="num"></canvas>
      <h3>预测结果:{{num}}</h3>
      <div class="btns">
        <button @click="reset">reset</button>
        <button @click="submit">submit</button>
      </div>
    </div>
  </div>
</template>
<script>
import * as tf from '@tensorflow/tfjs';
import {loadFrozenModel} from '@tensorflow/tfjs-converter';
var VConsole = require('vconsole/dist/vconsole.min.js');
new VConsole();
let t = tf.tensor([1, 2, 3])
console.log(t.dataSync(), typeof t.dataSync())
console.log(t, typeof t)
console.log(t[0], typeof t[0])
let s = JSON.stringify(t)
console.log(s, typeof s)
export default {
name: "draw",
data() {
return {
isDraw: false,
img: Array(28 * 28).fill(0),
num: '?',
model: '',
}
},
methods: {
draw(i) {
this.$set(this.img, i - 1, 0 + !this.img[i - 1])
},
async submit() {
let x = tf.tensor(this.img, [1, 28, 28, 1])
x.print()
console.log(this.img)
let p = this.model.predict({
"Placeholder": x,
"Placeholder_2": tf.scalar(1.),
})
p.print()
console.log(p)
let ctx = document.getElementById('num')
tf.toPixels(tf.tensor(this.img, [28, 28, 1]), ctx)
this.num = p.dataSync()[0]
},
reset() {
this.img = Array(28 * 28).fill(0)
this.num = '?'
}
},
async mounted() {
const MODEL_DIR = './static/tfjs_mnist/';
const MODEL_URL = 'tensorflowjs_model.pb';
const WEIGHTS_URL = 'weights_manifest.json';
this.model = await tf.loadFrozenModel(
MODEL_DIR + MODEL_URL,
MODEL_DIR + WEIGHTS_URL);
console.log(this.model)
}
}
</script>
<style>
/* Full-viewport column that centers the board, preview, result and buttons. */
.main {
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
  width: 100vw;
  height: 100vh;
}

/* 28x28 drawing board: one grid column per pixel. */
.grid {
  display: grid;
  grid-template-columns: repeat(28, 1fr);
  border: 1px solid black;
  width: 700px;
  height: 700px;
}

/* Every cell gets the same outline; the fill distinguishes empty vs drawn. */
.cell_gray,
.cell_black {
  border: 1px solid black;
}

.cell_gray {
  background: rgb(233, 233, 233);
}

.cell_black {
  background: black;
}

/* Preview canvas (its 28x28 pixels are scaled up by CSS). */
.num {
  border: 1px solid black;
  margin: 5px;
  width: 140px;
  height: 140px;
}

/* reset / submit buttons, pushed to either side. */
.btns {
  display: flex;
  justify-content: space-between;
  margin: 5px;
  width: 150px;
}
</style>
Python 模型保存代码,需要安装tensorflowjs
数据集下载
链接:https://pan.baidu.com/s/1SpMscvnUcNc3J22zCX0skw
提取码:0jot
注意该模型中含有keepProb参数,也就是说输入张量有两个
先将模型保存为 pb 格式(必须指定输出节点名称,这里为 prediction),再转换为 tfjs 使用的格式;在 tfjs 中传入输入张量即可完成预测。
# Freeze the graph: turn variables into constants for the 'prediction' output node.
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['prediction'])
# Write the frozen GraphDef (.pb file) to SAVE_PATH.
with tf.gfile.FastGFile(SAVE_PATH, mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# Convert that same file to the tfjs format, again naming 'prediction' as output.
tfjs.converters.tf_saved_model_conversion.convert_tf_frozen_model(SAVE_PATH, 'prediction', './tfjsmodel')
完整代码
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.python.framework import graph_util
import tensorflowjs as tfjs
import os
# CSV paths for the training / test data (main() skips the first row as a header).
TRAIN_PATH = "D:/data/Digit Recognizer/train.csv"
TEST_PATH = "D:/data/Digit Recognizer/test.csv"
learning_rate = .0001  # Adam learning rate
TRAIN_STEP = 50000  # number of training iterations
BATCH_SIZE = 512  # mini-batch size; periodic evaluation samples BATCH_SIZE * 10
SHOW_STEP = 100  # print loss/accuracy every SHOW_STEP iterations
# Graph inputs. Keep this creation order: unnamed placeholders are auto-named
# "Placeholder", "Placeholder_1", "Placeholder_2", and the web client feeds the
# image and keep_prob by the names "Placeholder" and "Placeholder_2".
in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))  # images, N*28*28*1
in_y = tf.placeholder(tf.float32, (None, 10))  # one-hot labels, N*10
keep_prob = tf.placeholder(tf.float32)  # dropout keep probability (1. at inference)
SAVE_DIR = './pb/'
SAVE_PATH = os.path.join(SAVE_DIR, 'graph.pb')  # frozen GraphDef written by main()
# Per-image preprocessing hook; callers reshape the result to (28, 28, 1).
def preprocessing(img):
    """Return *img* unchanged.

    This is a placeholder for per-image preprocessing. The caller reshapes
    whatever is returned to (28, 28, 1).
    """
    processed = img
    return processed
# The network input is the N*28*28*1 placeholder `in_x`.
def get_net():
    """Build the CNN classifier on top of the module-level placeholders.

    Reads the global placeholders ``in_x`` (images) and ``keep_prob``
    (dropout keep probability). Returns the (N, 10) logits tensor.
    """
    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            activation_fn=tf.nn.relu6,
    ):
        net = slim.conv2d(in_x, 64, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 14, 14, 64)
        net = slim.conv2d(net, 32, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 7, 7, 32)
        net = slim.conv2d(net, 16, (3, 3), stride=1)
        net = slim.max_pool2d(net, (2, 2), stride=2, padding='SAME')
        print(net.shape)  # (?, 4, 4, 16)
        net = slim.flatten(net)
        print(net.shape)  # (?, 256) = 4 * 4 * 16
        net = slim.fully_connected(net, 128)
        net = slim.dropout(net, keep_prob)
        print(net.shape)  # (?, 128)
        net = slim.fully_connected(net, 32)
        net = slim.dropout(net, keep_prob)
        print(net.shape)  # (?, 32)
        # Output layer: logits must be linear. The arg_scope would otherwise
        # apply relu6 (clipping scores to [0, 6]). There is also no dropout
        # here, because randomly zeroing logits corrupts the softmax
        # cross-entropy loss during training.
        net = slim.fully_connected(net, 10, activation_fn=None)
        print(net.shape)  # (?, 10)
    return net
net = get_net()
# softmax_cross_entropy_with_logits_v2 takes one-hot labels and raw (unscaled) logits.
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=in_y, logits=net))
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)
# Named 'prediction' so the frozen graph / tfjs model can expose it as the output node.
prediction_result = tf.argmax(net, 1, name='prediction')
# Reuse the prediction op instead of building a second argmax over the same logits.
prediction_tr = tf.equal(prediction_result, tf.argmax(in_y, 1))
accuracy_tr = tf.reduce_mean(tf.cast(prediction_tr, tf.float32))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Predict digit labels for a batch of images with the current session.
def get_ans(image_data):
    """Return the predicted class index for every image in *image_data*.

    image_data: array shaped N*28*28*1 (main() scales pixels to [0, 1]).
    Returns a length-N array of argmax indices over the network output.
    Dropout is disabled by feeding keep_prob = 1.
    """
    feed = {in_x: image_data, keep_prob: 1.}
    logits = sess.run(net, feed)
    return np.argmax(logits, axis=1)
# Train the model on (data, label) using random mini-batches.
def train(data, label):
    """Run TRAIN_STEP optimizer steps on random mini-batches of *data*/*label*.

    data:  N*28*28*1 float images.
    label: N*10 one-hot labels.

    Batches are drawn with replacement. Every SHOW_STEP steps, the loss and
    accuracy on a larger random sample (BATCH_SIZE * 10, dropout off) are
    printed. This function only trains; it does not save the model.
    """
    n_samples = len(data)
    for i in range(1, 1 + TRAIN_STEP):
        # Passing the population size (an int) samples indices from arange(n)
        # directly, instead of converting a range object to an array each step.
        ids = np.random.choice(n_samples, BATCH_SIZE)
        sess.run(train_op, {
            in_x: data[ids],
            in_y: label[ids],
            keep_prob: .5,
        })
        if i % SHOW_STEP == 0:
            ids = np.random.choice(n_samples, BATCH_SIZE * 10)
            loss_val, accuracy_val = sess.run(
                [loss_op, accuracy_tr], {
                    in_x: data[ids],
                    in_y: label[ids],
                    keep_prob: 1.,
                }
            )
            print(i, loss_val, accuracy_val)
def _load_images(pixel_rows):
    """Turn CSV pixel rows (string cells) into an N*28*28*1 float32 array in [0, 1]."""
    images = pixel_rows.reshape((-1, 28, 28, 1)).astype(np.uint8)
    images = np.stack(
        [preprocessing(img).reshape((28, 28, 1)) for img in images]
    ).astype(np.float32)
    images /= 255
    return images


def _predict_in_batches(images, n_batches=20):
    """Predict labels chunk by chunk (bounding memory); returns a 1-D array."""
    # array_split, unlike split, accepts sizes not divisible by n_batches;
    # clamping avoids empty chunks when there are fewer rows than batches.
    n_batches = max(1, min(n_batches, len(images)))
    return np.hstack([get_ans(batch) for batch in np.array_split(images, n_batches)])


def _write_submission(ans, path='ans.txt'):
    """Write predictions as 'ImageId,Label' CSV rows with 1-based ImageIds."""
    with open(path, mode='w', encoding='utf8') as f:
        f.write('ImageId,Label\n')
        f.writelines(f"{i + 1},{j}\n" for i, j in enumerate(ans))


def _export_model():
    """Freeze the graph to SAVE_PATH and convert it to the tfjs format."""
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['prediction'])
    # FastGFile does not create missing parent directories.
    os.makedirs(SAVE_DIR, exist_ok=True)
    with tf.gfile.FastGFile(SAVE_PATH, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    tfjs.converters.tf_saved_model_conversion.convert_tf_frozen_model(
        SAVE_PATH, 'prediction', './tfjsmodel')


def main():
    """Train on TRAIN_PATH, write TEST_PATH predictions to ans.txt, export the model."""
    # The first CSV row is a header; in the training file, column 0 is the label.
    # dtype=str: np.str was only an alias of str and was removed in NumPy 1.24.
    data = np.loadtxt(TRAIN_PATH, dtype=str, delimiter=',')
    image_data = _load_images(data[1:, 1:])
    label_data = tf.keras.utils.to_categorical(data[1:, 0], 10).astype(np.float32)
    train(image_data, label_data)

    test_data = np.loadtxt(TEST_PATH, dtype=str, delimiter=',')
    image_data = _load_images(test_data[1:, :])
    # Predicting the whole test set at once is memory-heavy, so do it in chunks.
    ans = _predict_in_batches(image_data)
    print(ans.shape)
    _write_submission(ans)
    _export_model()


if __name__ == '__main__':
    main()