KNN算法和欧式距离介绍
1. KNN算法又称为k近邻分类(k-nearest neighbor classification)算法。
最简单平凡的分类器也许是那种死记硬背式的分类器,记住所有的训练数据,对于新的数据则直接和训练数据匹配,如果存在相同属性的训练数据,则直接用它的分类来作为新数据的分类。这种方式有一个明显的缺点,那就是很可能无法找到完全匹配的训练记录。
KNN算法则是从训练集中找到和新数据最接近的k条记录,然后根据他们的主要分类来决定新数据的类别。该算法涉及3个主要因素:训练集、距离或相似的衡量、k的大小。
训练模式得到每一份形状(后转变成数组)后,输出存储。将识别模式下得到的形状(后转变成数组)和取出的文件相比,算出距离distance(K)存入数组,得到最小的distance(K)输出对应数字。
2.欧式距离
二维空间a(x1,y1),b(x2,y2)两点间的距离k = sqrt((x1-x2)^2+(y1-y2)^2)
可拓展三维,四维,n维......
操作:如图界面,选择训练单选按钮进入训练模式,选择下拉框数字决定训练数字,单击输入区域清屏并且保存样本。选择
识别按钮进入识别模式,写完数字后单击输入区域清屏并且弹出对话框,显示识别结果。
1.界面设计
界面很简单,继承JFrame,流式布局。添加两个单选按钮组件(JRadioButton)区分识别模式和训练模式。注意要将两耳光单选按钮加入ButtonGroup之中,保证只有一个按钮会在同一时间被选用。添加下拉框组件(JComboBox)区分训练的数字。添加JPanel为手写区域,防止数字写在按钮上,设置黑色方便区分。
public void showUI() {
this.setTitle("简单识别");
//只有JFrame可以使用setsize,jJPanel要使用setPreferredSize
this.setSize(600, 700);
this.setDefaultCloseOperation(3);
this.setLocationRelativeTo(null);
this.setLayout(new FlowLayout());
//创建两个单选按钮
JRadioButton jrb1 = new JRadioButton ("识别");
JRadioButton jrb2 = new JRadioButton ("训练");
//创建按钮组,将两个单选按钮添加进按钮组
ButtonGroup bg = new ButtonGroup();
bg.add(jrb1);
bg.add(jrb2);
//设置第一个单选按钮选中
jrb1.setSelected(true);
//将单选按钮加入界面
this.add(jrb1);
this.add(jrb2);
//添加JComboBox组件,区分训练数字
String[] label = {"0","1","2","3","4","5","6","7","8","9"};
JComboBox jcb = new JComboBox(label);
this.add(jcb);
JPanel jp = new JPanel();
jp.setPreferredSize(new Dimension(500,570));
jp.setBackground(Color.DARK_GRAY);
this.add(jp);
this.setVisible(true);
Graphics g = jp.getGraphics();
//继承MouseAdapter时,添加监听器时要将需要使用的MouseListener,MouseMotionListener添加
Listener l = new Listener(g,jcb,this,jrb1,jrb2);
jp.addMouseListener(l);
jp.addMouseMotionListener(l);
jrb1.addActionListener(l);
jrb2.addActionListener(l);
jcb.addActionListener(l);
}
2.将写入的数字转化为二维数组,将其简化为一个25*25的二维数组区域,每一个在单独二维数组区域得到的点都被固定在二维数组区域(类似五子棋的落子确定点,飞机大战的碰撞规则),鼠标拖动监听得到鼠标经过的位置,将其二维数组标记为1,没经过的为0。可得到类似二维数组(图为0)
3.存取文件(i/0流)
因为KNN比较笨,所以需要更多的样本进行比较,因此离不开文件的读取操作。
将二维数组存入文件时,要区分存入的文件数字,同时要保证文件名不同,可以采用“手写数字”+“系统时间”的命名方式。获取系统时间的方法是System.currentTimeMillis()
// 按钮选择识别模式或者训练模式
@Override
public void actionPerformed(ActionEvent e) {
if (jrb1.isSelected()) {
choose = 1;
System.out.println("识别");
}else if (jrb2.isSelected()) {
// 选择训练模式
choose = 2;
System.out.println("训练");
String name = e.getActionCommand();
if (name.equals("comboBoxChanged")) {
// System.out.println("监听训练数字");
// System.out.println(name);
String str = jcb.getSelectedItem().toString();
System.out.println(str);
num = Integer.parseInt(str); // 字符串转化为整型
// System.out.println(num);
}
}
}
// 输出到文件
public void out() throws IOException {
String fileName = "D:\\Java Class\\io输出\\手写识别\\" + num + "-" + System.currentTimeMillis() + ".txt";
File file = new File(fileName); // 创建文件
FileWriter out = new FileWriter(file);// FileWriter 用来写入字符文件的便捷类
for (int i = 0; i < SIZE; i++) {
for (int j = 0; j < SIZE; j++) {
out.write(array[i][j] + "");
}
}
out.flush();
out.close();
}
识别时要进行取文件,读取文件时需要分辨出文件样本对应的数字,由于存入时用的是数字命名,这里可以采用subString的方法获取文件首字符,并转换为数字。
// 读取文件
public void in() throws IOException {
File file = new File("D:\\Java Class\\io输出\\手写识别\\");
String[] filelist = file.list();// 得到该目录下所有文件名
// FileReader 用来读取字符文件的便捷类
for (int i = 0; i < filelist.length; i++) {
// 读取目录名
FileReader readfile = new FileReader("D:\\Java Class\\io输出\\手写识别\\" + filelist[i]);
// 读取文件,输入流读取返回int
int b = readfile.read();
int m = 0, n = 0;
//read读取下一个数据字节,如果到达流末尾,则返回 -1
while (b != -1) {
// 将从文件名中取出的数存入数组
int c = Integer.valueOf((char)b+"");
getfilename[m][n] = c;
n++;
if (n == SIZE) {
n = 0;
m++;
}
// 继续读取下一个文件字符
b = readfile.read();
}
//关闭文件
readfile.close();
// 提取文件名首数字
int a = Integer.valueOf(filelist[i].substring(0, 1));
// 调用计算距离的公式,将得到的距离等传给对应的数组变量
getdistance(getfilename, array, a);
}
}
4.计算距离进行识别操作
计算距离以后进行距离排序,识别输出距离最小值,也最有可能为被识别数字。样本越多识别越准确。
@Override
public void mouseClicked(MouseEvent e) {
System.out.println("鼠标点击了");
if (e.getClickCount() == 2||e.getButton() == 1) {
// 双击重绘
jf.repaint();
}
if (choose == 1) {
System.out.println("开始识别");
// 识别模式
try {
// 调用读取数据方法
in();
} catch (IOException e1) {
e1.printStackTrace();
}
// 将所有求出的欧式距离从小到大排序
for (int i = 0; i < number.size() - 1; i++) {
for (int j = 0; j < number.size() - 1 - i; j++) {// -1防止溢出
if (distance.get(j) > distance.get(j + 1)) {
int temp = distance.get(j);
distance.set(j, distance.get(j + 1));
distance.set(j + 1, temp);
int tempnum = number.get(j);
number.set(j, number.get(j + 1));
number.set(j + 1, tempnum);
}
}
}
// 取距离最短的三个
int result = number.get(0);
if (number.get(1) == number.get(2)) {
result = number.get(1);
}
JOptionPane.showMessageDialog(null, "输入的数字是:" + result);
}
}
// 如果处于训练模式
if (choose == 2) {
System.out.println("开始训练");
try {
// 调用输出数据方法
out();
} catch (IOException e1) {
e1.printStackTrace();
}
}
// 打印输入的矩阵,并清空初始化
for (int i = 0; i < SIZE; i++) {
for (int j = 0; j < SIZE; j++) {
array[j][i] = 0;
}
}
number.clear();
distance.clear();
}