一、KNN手写识别原理
在下图中,要判断绿色圆归属为哪个类(红三角形还是蓝四边形)
如果K=3,因为红三角形占比例为2/3,所以绿色圆归属为红色三角形;
如果K=5,因为蓝四边形比例为3/5,所以绿色圆被赋予蓝色四方形类。
那么如何计算上图中各个图形距离绿圆的距离呢?KNN使用的是欧氏距离原理:(可以把图形当作一个点)计算两个点间的距离
这个是二维的,如果是三维或者多维的欧氏距离:
两个n维向量a(x1,x2,…,xn)与b(x1,x2,…,xn)间的欧式距离
二、基本思路
- 数字保存的画布大小用32×32的二维数组表示。
- 如下图中一个大方格相当于10×10的像素,一个大方格就是32×32的二维数组中的一个元素
- 当画笔画到这个方格的区域时,就将这个方格以及周围的8个方格(上下左右,斜线)都标记为1(一开始我只标记一个,效果很差,要达到下图2效果),其他没有划线部分初始化为0
- 计算当前手写的和样本的欧式距离,按距离升序排序,取前K个。
- 从前K个排序好的距离中,取出出现频率最高的那个数字
- 为了弥补KNN的误差,可以把与样本匹配到的数字(可能结果)都输出来,以供用户选择
- 除以上的思路外,还可以自己添加训练集,提高识别准确率,增加一个保存的功能
- 思维导图
借鉴博客:https://blog.csdn.net/weixin_42621338/article/details/81989035
看效果(默认识别功能,可选择,写完数字后点击鼠标即可实现识别)
三、附录代码
- 目前代码还不是很完善,没有加入点击判断(是否写有数字)
- 没有橡皮擦~ 简陋版就先这样,慢慢优化
用到两个文件操作:写入、读取。
写
public void out() throws IOException{
//创建文件名为:保存为数字+当前时间
File ins=new File("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/"+number+'-'+System.currentTimeMillis());
//写入文件
FileWriter write=new FileWriter(ins);
for(int i=0;i<Area.arraySize;i++){
for(int j=0;j<Area.arraySize;j++){
write.write(draw[i][j]+"");
}
}
write.flush();
write.close();
}
读
public void in() throws IOException{
File file =new File ("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/");
//得到该目录下的所有文件名
String[] filelist=file.list();
for(int i=0;i<filelist.length;i++){
FileReader readfile=new FileReader("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/"+filelist[i]);
//读取文件,输入流读取返回int
int b=readfile.read();
for(int k=0;k<Area.arraySize;k++){
for(int j=0;j<Area.arraySize;j++){
//文章末为-1
if(b!=-1){
//从文件名中取出的数存入数组
getfile[k][j]=Integer.valueOf((char)b+"");
//读取下一个
b=readfile.read();
}
}
}
//关闭文件
readfile.close();
//提取文件名首数字
int a =Integer.valueOf(filelist[i].substring(0, 1));
// System.out.println("本次检查数字:"+a);
//调用计算距离公式,得到的距离传给距离数组
setdistance(getfile,draw,a);
}
}
显示窗体:
package com.Digital;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.TextArea;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JFrame;
import javax.swing.JPanel;
public class ShowUI {
/**
* 主函数
* @param argc
*/
public static void main(String [] argc){
ShowUI s=new ShowUI();
s.show();
}
/**
* 显示窗体方法
*/
public void show(){
/*
* 创建窗体步骤:
* 1.设置窗体名字
* 2.设置窗体大小
* 3.设置窗体关闭方法
* 4.设置窗体居中显示
* 5.设置窗体布局
* 6.设置窗体可见
*/
JFrame jf=new JFrame("手写识别");
jf.setSize(400, 480);
jf.setDefaultCloseOperation(3);
jf.setLocationRelativeTo(null);
jf.setLayout(new FlowLayout());//流式布局
//添加功能按钮
JButton iden=new JButton("识别");
JButton train=new JButton("训练");
jf.add(iden);
jf.add(train);
//下拉框
String []number={"0","1","2","3","4","5","6","7","8","9"};
JComboBox<String> num=new JComboBox<String>(number);
jf.add(num);
//写字识别处
JPanel jpw=new JPanel();
jpw.setPreferredSize(new Dimension(Area.Size, Area.Size));
jpw.setBackground(Color.white);
jf.add(jpw);
//显示识别结果,用TextArea(自带滚动条),如果是JTextArea需要自行绑定滚动条
TextArea jps=new TextArea(2,30);
jps.setBackground(Color.white);
jf.add(jps);
jf.setVisible(true);
//获取JPanel上的画笔
Graphics g=jpw.getGraphics();
//新建监听器,下面在各个组件上添加
Listener l=new Listener(g, jpw,jps);
iden.addActionListener(l);
train.addActionListener(l);
num.addItemListener(l);
jpw.addMouseListener(l);
jpw.addMouseMotionListener(l);
}
/**
* 基本设计,画布大小
* @author mo
*
*/
public interface Area{
int Size=320;
int arraySize=32;
int K=3;
}
}
监听器:
package com.Digital;
import java.awt.BasicStroke;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.TextArea;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.event.MouseMotionListener;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import javax.swing.JComboBox;
import javax.swing.JPanel;
import javax.swing.JTextArea;
import com.Digital.ShowUI.Area;
import com.MyArrayList.MyArrayList;
public class Listener implements ActionListener,MouseMotionListener,MouseListener,ItemListener {
private int [][]draw=new int [Area.arraySize][Area.arraySize];//保存每一个位置的二维数组
private int [][]getfile=new int[Area.arraySize][Area.arraySize];
private int []count=new int[10];
private MyArrayList<Distance> distance=new MyArrayList<Distance>();
private Graphics2D g;
private int x1,y1;
private JPanel jp;
private TextArea jt;
private String name="识别";
private int number;
/**
* 二维数组的初始化
*/
public void init(){
for(int i=0;i<Area.arraySize;i++){
for(int j=0;j<Area.arraySize;j++){
draw[i][j]=0;
}
}
int []c={0,0,0,0,0,0,0,0,0,0};
count=c;
distance.clear();
}
/**
* 构造传参
* 画笔+画板
* @param g
* @param jp
*/
public Listener(Graphics g,JPanel jp,TextArea jt){
this.g=(Graphics2D) g;
this.g.setStroke(new BasicStroke(10.0f));
this.jp=jp;
this.jt=jt;
}
/**
* 选择要识别的数字
*/
public void itemStateChanged(ItemEvent e){
number=Integer.parseInt((String)e.getItem());
System.out.println(number);
}
/**
* 选择模式
*/
public void actionPerformed(ActionEvent e){
if(e.getActionCommand().equals("识别"))
name="识别";
else if(e.getActionCommand().equals("训练")){
name="训练";
}
}
/**
* 画曲线
*/
public void mouseDragged(MouseEvent e){
g.drawLine(x1, y1, e.getX(), e.getY());
//求二维数组下标
int index_x=x1/10;
int index_y=y1/10;
//用四舍五入的方法判断是落在哪个格子(点)中
if(x1%10>4)
index_x++;
if(y1%10>4)
index_y++;
for(int i=-1;i<2;i++){
if(index_x+i>=0&&index_x+i<Area.arraySize)
draw[index_y][index_x+i]=1;
if(index_y+i>=0&&index_y+i<Area.arraySize)
draw[index_y+i][index_x]=1;
}
x1=e.getX();
y1=e.getY();
}
/**
* Invoked when the mouse cursor has been moved onto a component
* but no buttons have been pushed.
*/
public void mouseMoved(MouseEvent e){}
public void mouseClicked(MouseEvent e){
//输出记录的矩阵,检查是否符合预期
// for(int i=0;i<Area.arraySize;i++){
// for(int j=0;j<Area.arraySize;j++){
// System.out.print(draw[i][j]);
// }
// System.out.println(" ");
// }
// System.out.println("_____________________________________________");
//if(e.getClickCount()==2){
//训练模式直接写入保存就可以
if(name.equals("训练")){
try{
out();
}catch(IOException e0){
e0.printStackTrace();
}
}
//识别模式
if(name.equals("识别")){
try{
//读取数据
in();
}catch(IOException e0){
e0.printStackTrace();
}
//(以上得到距离数组)根据距离排序
// System.out.println(distance.getsize()+"大小");
for(int i=0;i<distance.getsize();i++){
for(int j=i;j<distance.getsize();j++){
if(distance.getObject(i).dis>distance.getObject(j).dis)
distance.swap(i, j);
}
}
//得出K个临近的数字出现的频率
for(int i=0;i<Area.K;i++){
for(int j=0;j<10;j++){
// System.out.println("K计数的数字:"+distance.getObject(i).number);
if(distance.getObject(i).number==j)
count[j]++;
}
}
int result=0,max=0;//得出最大频率的数字就是结果
for(int i=0;i<10;i++){
// System.out.println("最大频率:"+max);
if(max<count[i]){
max=count[i];
result=i;
}
}
//不确定是否K个中只识别到一个数字,所以是可能
System.out.println("识别的数字是:"+result);
jt.append("识别的数字可能为:"+result);
//如果前K个数字中也识别到其他的数字,也把他们输出
for(int i=0;i<10;i++){
if(count[i]>0&&result!=i)
jt.append(" "+i);
}
jt.append("\n");
}
//本次操作结束,重绘且更新数组
jp.paint(g);
init();
//}
}
/**
* Invoked when a mouse button has been pressed on a component.
*/
public void mousePressed(MouseEvent e){
x1=e.getX();
y1=e.getY();
}
/**
* Invoked when a mouse button has been released on a component.
*/
public void mouseReleased(MouseEvent e){}
/**
* Invoked when the mouse enters a component.
*/
public void mouseEntered(MouseEvent e){}
/**
* Invoked when the mouse exits a component.
*/
public void mouseExited(MouseEvent e){}
/*
* 保存训练的数据
*/
public void out() throws IOException{
//创建文件名为:保存为数字+当前时间
File ins=new File("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/"+number+'-'+System.currentTimeMillis());
//写入文件
FileWriter write=new FileWriter(ins);
for(int i=0;i<Area.arraySize;i++){
for(int j=0;j<Area.arraySize;j++){
write.write(draw[i][j]+"");
}
}
write.flush();
write.close();
}
/**
* 读取数据
*/
public void in() throws IOException{
File file =new File ("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/");
//得到该目录下的所有文件名
String[] filelist=file.list();
for(int i=0;i<filelist.length;i++){
FileReader readfile=new FileReader("D:/learning/mydemo/Javaworkspace/rect/src/com/手写识别/"+filelist[i]);
//读取文件,输入流读取返回int
int b=readfile.read();
for(int k=0;k<Area.arraySize;k++){
for(int j=0;j<Area.arraySize;j++){
//文章末为-1
if(b!=-1){
//从文件名中取出的数存入数组
getfile[k][j]=Integer.valueOf((char)b+"");
//读取下一个
b=readfile.read();
}
}
}
//关闭文件
readfile.close();
//提取文件名首数字
int a =Integer.valueOf(filelist[i].substring(0, 1));
// System.out.println("本次检查数字:"+a);
//调用计算距离公式,得到的距离传给距离数组
setdistance(getfile,draw,a);
}
}
/**
* 计算当前写的数字与样本的欧式距离
* @param a 32*32数组
* @param b 32*32数组
* @param n 数字
*/
public void setdistance(int [][]a,int [][]b,int n){
int sum=0;
for(int i=0;i<Area.arraySize;i++){
for(int j=0;j<Area.arraySize;j++){
sum+=(int)Math.pow((double)(a[i][j]-b[i][j]), 2);
}
}
sum=(int)Math.sqrt((double)sum);
Distance d=new Distance(sum,n);
distance.add(d);
// System.out.println(n+"数字距离为:"+sum);
}
/**
* 内部保存距离的类
* 有距离,数字,构造函数
* @author mo
*
*/
class Distance{
public int dis;
public int number;
public Distance(int d,int n){
this.dis=d;
this.number=n;
}
public Distance(){}
}
}