目的为整合所学知识,捡一下以前学习的代码,编程语言包括JavaScript、Python、HTML。涉及到神经网络、网页设计、网页异步加载,还有一些配置问题。只是个简单的demo,连准确率都不怎么在乎,主要是设计流程。
一、网页的设计
网页功能主要分为几块,
1.写数字的部分
2.显示结果的部分
3.提交部分
4.清除/重置按钮
1.写数字的部分
利用HTML5里的<canvas>来实现,实现代码来源于网上,我这里只做了个粘合剂,直接拿来用了。
定义id为canvas,宽和高为280px的画布
<canvas id="canvas" width="280" height="280" style="margin-bottom: 10px;"></canvas>
JavaScript,主要是注册画布的监听函数和配置画笔属性
var isDown = false;
var points = [];
var beginPoint = null;
const canvas = document.querySelector('#canvas');
const ctx = canvas.getContext('2d');
// 设置线条颜色
ctx.strokeStyle = 'red';
ctx.lineWidth = 40;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
canvas.addEventListener('mousedown', down, false);
canvas.addEventListener('mousemove', move, false);
canvas.addEventListener('mouseup', up, false);
canvas.addEventListener('mouseout', up, false);
function down(evt) {
isDown = true;
const { x, y } = getPos(evt);
points.push({x, y});
beginPoint = {x, y};
}
function move(evt) {
if (!isDown) return;
const { x, y } = getPos(evt);
points.push({x, y});
if (points.length > 3) {
const lastTwoPoints = points.slice(-2);
const controlPoint = lastTwoPoints[0];
const endPoint = {
x: (lastTwoPoints[0].x + lastTwoPoints[1].x) / 2,
y: (lastTwoPoints[0].y + lastTwoPoints[1].y) / 2,
}
drawLine(beginPoint, controlPoint, endPoint);
beginPoint = endPoint;
}
}
function up(evt) {
if (!isDown) return;
const { x, y } = getPos(evt);
points.push({x, y});
if (points.length > 3) {
const lastTwoPoints = points.slice(-2);
const controlPoint = lastTwoPoints[0];
const endPoint = lastTwoPoints[1];
drawLine(beginPoint, controlPoint, endPoint);
}
beginPoint = null;
isDown = false;
points = [];
}
function getPos(evt) {
return {
x: evt.clientX,
y: evt.clientY
}
}
function drawLine(beginPoint, controlPoint, endPoint) {
ctx.beginPath();
ctx.moveTo(beginPoint.x, beginPoint.y);
ctx.quadraticCurveTo(controlPoint.x, controlPoint.y, endPoint.x, endPoint.y);
ctx.stroke();
ctx.closePath();
}
function showLoading(show){
if(show)
{
document.getElementById("over").style.display = "block";
document.getElementById("layout").style.display = "block";
}
else
{
document.getElementById("over").style.display = "none";
document.getElementById("layout").style.display = "none";
}
}
2.显示结果的部分
显示结果,就定义了一个文本框,让识别的结果显示在这里就好。文本框id=‘txt’。
<div>
<a>结果为:</a><input name="txt" type="text" id = "result"/>
</div>
3.提交部分
提交部分的功能要相对复杂一点,要实现的目标功能是:按下提交按钮,浏览器将画布上的数据传输到服务器,服务器将这个数据送入神经网络,神经网络给出识别结果,服务器将结果送回浏览器,浏览器将结果展示在显示结果的文本框中。
按钮好实现,代码如下
<input value="提交" type="button" id = "submit" style="margin-left: 300px;"/>
功能实现我考虑使用异步加载技术,因为我Django框架就简单看了一会,对Django框架设计网页的风格还接受不了,又想尽快实现功能。最终采用了JQuery的异步加载方法。
JQuery对提交按钮注册的click函数
$("#submit").click(function()//#sumbit 是JQuery的id选择器用法,即选中了id为submit的页面元素
{
showLoading(true);//这个函数我后面会交代,为了增加体验做了一个等待的界面,按下提交按钮,就弹出等待动画
//直到把对应属性置为false
var testimg = [];
//var senddata = []
var imgData=ctx.getImageData(0,0,canvas.width,canvas.height);//拿到画布数据
for (var i=0;i<imgData.data.length;i+=4)
{
testimg[i/4] = imgData.data[i];//这里请看一下canvas的文档,我只拿了red图层的数据,因为画笔是红色的
//拿其他数据没有意义
}
console.log(testimg.length);
//console.log(testimg);
$.post("http://***.***.***.***:8000/hello",{'cd':testimg.toString()},function(data,status){
/*
这里使用jQuery的post方法,将数据转成字符串送给服务器,这里有个问题,如果不转字符串,服务器
接收到的数据会有问题,可以自行验证。
*/
showLoading(false);//这里置flase了,等待界面消失
document.getElementById("result").value = data;//这里将服务器返回的数据展示到文本框中
console.log(data.length);
alert("Data: " + data + "\nStatus: " + status);
});
//alert("submit success!!");
});
请注意,这里有两个问题:
1.开始时我网页是在PC端设计并调试,Django是在服务器端,二者不是同一IP,所以浏览器在post数据给服务器会因跨域问题而拒绝访问,请自行百度Django允许跨域设置。
2.post数据太大,超出Django默认大小。解决方法:看一下Django报错信息,看看是那个字段的大小限制了post的数据量,并将这个字段在对应工程的settings.py中赋值为None。
4.清除/重置按钮
这个就比较简单了。
显示代码
<input value="清除" type="button" id = "cls"/>
对应JavaScript代码,也是用的JQuery框架的写法,注意下
$("#cls").click(function()
{
var clearimgdata = ctx.createImageData(canvas.width,canvas.height);
ctx.putImageData(clearimgdata,0,0);//画布清零
document.getElementById("result").value = '';//清除结果
alert("cls success!!");
});
ok,贴一下页面总体图

界面可简单,但是可以用*_*"。
二、服务器端的设计
也就是所谓的后台设计了。根据菜鸟教程,创建django工程


templates是后来加上的,如果你有MVC设计模式的经验,就能很快上手。
在这个工程里我忽略了模板的作用,主要是view.py与urls.py的设置与对应。注:这里的view.py也是自己创建的。
解释一下,url顾名思义,就是你使用浏览器输入的网址,输入完网址,回车!服务器会给你的浏览器一个响应,那这个响应是如何得出来的呢?得出逻辑就写在了view.py里,urls.py中注册网址与给出响应的函数的映射!!
给出urls.py的代码
from django.conf.urls import url
from . import view
urlpatterns = [
url(r'^hello$',view.hello),
]
#很明显,http://***.***.***/hello这个网址对应着view里的hello函数
给出view.py的代码,需要注意的东西写到代码注释里了
from django.http import HttpResponse
from django.shortcuts import render
import numpy as np
from PIL import Image
import tensorflow.compat.v1 as tf
from . import mnist_backward
from . import mnist_forward#请注意这个点,Django中引用同一文件夹下的格式
def restore_model(testPicArr):#这个就是加载神经网络模型
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue = tf.argmax(y, 1)
variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x:testPicArr})
return preValue
else:
print("No checkpoint file found")
return -1
def hello(request):
if request.POST:#判断请求形式,如果是post方式传输数据
# keysList = list(request.POST.keys())
getData = list(request.POST.values())[0]
splitData = getData.split(',')
# response = HttpResponse('POST:Hello world!!'+str(list(request.POST.values())[0]))
np_data = np.array(splitData, dtype = np.float32)
b = np.resize(np_data,(280,280))#原始画布数据
b_int = b.astype(np.uint8)
im = Image.fromarray(b_int)
im = im.resize((28,28),Image.ANTIALIAS)# resize画布数据
resdata = np.array(im,dtype = np.float32)/255.0# 归一化
resdata = np.resize(resdata,(1,28*28))# 拉直,用来输入
response = HttpResponse(str(restore_model(resdata)))# 将数据喂入神经网络
elif request.GET:
response = HttpResponse('GET:success!!')
else:
response = HttpResponse('Other:Hello world!!')
response["Access-Control-Allow-Origin"] = "*"
response["Access-Control-Allow-Methods"] = "POST, GET,OPTIONS"
response["Access-Control-Max-Age"] = "1000"
response["Access-Control-Allow-Headers"] = "*"
#print(resquest)
return response
整个工程很简单吧,如果你想要更复杂的工程,只需要找一个好的UI,训练一个好的、高准确率的模型,在将数据喂入神经网络的部分其实还可以在优化,因为将神经网络加载到内存里需要的时间是有点久的,最好做成常驻内存的程序。
三、结果
贴下结果。



四、关于神经网络
就是一个简单的全连接网络,在tf1下训练的,输入是28*28,但是tf2的compat.v1y也使得它可用,就这样。
具体代码请移步:https://github.com/tian0zhi/Django-Tensorflow1-2

746

被折叠的 条评论
为什么被折叠?



