# 机器学习--knn手写数字识别系统

1143人阅读 评论(0)

0.k近邻算法

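1.knn算法简介：对于一个待分类的样本，kNN算法先计算它与训练集中每个样本的距离，再选出距离最近的k个样本，由这k个样本投票（多数表决）决定它的类别。本程序中每个样本是一个32×32的0/1矩阵，距离采用欧氏距离，k取3。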

2.该程序主要有以下几个功能：

功能1:可以在面板上手写输入数字

功能2:可以对特定的区域进行截屏：为了获取用户手写的数字，先把手写板区域保存为图像，然后再交给算法进行分析。

功能3:可以对图片进行缩放，保证图片的大小（维度）和数据集中样本的大小一致（32×32）。

功能4:可以将彩色图片转化为二值图片

功能5:对图片中的手写数字使用KNN算法进行识别，也可以在测试集上计算算法的准确性。

（演示）

/**
 * A 320x320 white drawing area on which the user writes a digit by dragging the mouse.
 *
 * <p>Strokes are painted straight onto the screen and are not stored, so any repaint of the
 * board (for example after the frame's {@code repaint()}) wipes them.
 */
class Board extends JPanel implements MouseMotionListener {
    final private int boardWidth = 320;
    final private int boardHeight = 320;
    final private int boardX = 1;
    final private int boardY = 1;
    private int pencilWidth = 40;

    /** Paints a black 3D frame around a white drawing area of boardWidth x boardHeight pixels. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps a black, pencil-sized filled circle at the drag position while it lies inside the area.
     *
     * <p>Only takes effect once this board is registered as a MouseMotionListener; no registration
     * happens inside this class.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        // fillOval is anchored at its top-left corner, so keep one pencil width of margin
        // on the right/bottom to stay inside the white area.
        boolean inside = x > 1 && x < boardWidth - this.pencilWidth && y > 1 && y < boardHeight - pencilWidth;
        if (!inside) {
            return;
        }
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            // Not displayable yet: there is nothing on screen to draw on.
            return;
        }
        try {
            // Explicit ink colour: the later binarization treats dark pixels as "written".
            graphics.setColor(Color.BLACK);
            graphics.fillOval(x, y, pencilWidth, pencilWidth);
        } finally {
            graphics.dispose();
        }
    }

    /** Plain mouse movement (no button pressed) draws nothing. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }
}

/** Captures a rectangle of the screen and saves it as a PNG in the working directory. */
class ScreenShot {
    private int startX;
    private int startY;
    private int width;
    private int height;
    private String saveTo;

    /**
     * @param startX   screen x coordinate of the capture's top-left corner
     * @param startY   screen y coordinate of the capture's top-left corner
     * @param width    width of the captured rectangle in pixels
     * @param height   height of the captured rectangle in pixels
     * @param filename base name (without extension) of the PNG written to the working directory
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // File.separator rather than a hard-coded "\\": on non-Windows systems ".\\x.png" is a
        // single file literally named ".\x.png", not "x.png" in the current directory.
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Captures the configured screen rectangle and writes it as PNG to the save path.
     * Failures (e.g. no screen access, write error) print a stack trace and write no file.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            Rectangle area = new Rectangle(startX, startY, width, height);
            BufferedImage bufferedImage = new Robot().createScreenCapture(area);
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}

/** Scales an image file and writes the result to zoominMaggie.png in the working directory. */
class ZoomImage {
    private String filename;
    private float scaling;

    /**
     * @param filename path of the image to scale
     * @param scaling  factor applied to both dimensions; values below 1 shrink the image
     *                 (0.1f turns a 320x320 capture into 32x32)
     */
    public ZoomImage(String filename, float scaling) {
        this.filename = filename;
        this.scaling = scaling;
    }

    /**
     * Reads the image, scales it by the configured factor and writes it as PNG to
     * zoominMaggie.png in the working directory. Read/write problems are reported on stderr.
     */
    public void zoom() {
        File file = new File(this.filename);
        try {
            BufferedImage source = ImageIO.read(file);
            if (source == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + file);
                return;
            }
            // Compute the target size once; never let it collapse to 0, which BufferedImage rejects.
            int targetWidth = Math.max(1, (int) (this.scaling * source.getWidth()));
            int targetHeight = Math.max(1, (int) (this.scaling * source.getHeight()));
            BufferedImage scaled = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = scaled.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, targetWidth, targetHeight, null);
            } finally {
                graphics.dispose();
            }
            ImageIO.write(scaled, "png", new File("." + File.separator + "zoominMaggie.png"));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

/**
 * Converts an image into a 32x32 0/1 feature vector: dark pixels become 1, light pixels 0.
 * The vector is stored row-major (index = row * 32 + column).
 */
class RGB2binary {
    /** Side length of the feature grid. */
    private static final int SIDE = 32;

    private String filename;
    private short[] userInputDigit = new short[SIDE * SIDE];

    /** Returns the feature vector filled by the last call to rgb2binary(). */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    /** @param filename path of the image to binarize */
    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image, prints it as a 0/1 grid to stdout and stores it in the feature vector.
     * Only the top-left 32x32 pixels are used; cells not covered by a smaller image stay 0.
     * Read problems are reported on stderr and leave the vector zeroed.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        // Clear results of an earlier call in place (callers may hold the array reference).
        for (int i = 0; i < userInputDigit.length; i++) {
            userInputDigit[i] = 0;
        }
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + file);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            // Rows follow the y axis and columns the x axis; clamp so a larger image cannot
            // overflow the fixed-size feature vector.
            int rows = Math.min(height, SIDE);
            int cols = Math.min(width, SIDE);
            for (int row = 0; row < rows; row++) {
                for (int col = 0; col < cols; col++) {
                    int pixel = bufferedImage.getRGB(startX + col, startY + row);
                    int r = (pixel & 0xff0000) >> 16;
                    int g = (pixel & 0xff00) >> 8;
                    int b = (pixel & 0xff);
                    // Luminance-weighted grey level, then a fixed threshold of 128.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit);
                    userInputDigit[row * SIDE + col] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

/**
 * k-nearest-neighbour classifier for 32x32 binary digit samples.
 *
 * <p>Each sample file must contain at least 32 lines of at least 32 digit characters; the label
 * is the part of the file name before the first '_' (e.g. "7_12.txt" has label 7). Files that do
 * not follow this format are skipped with a message on stderr.
 */
class Knn {
    /** Side length, in cells, of every sample grid. */
    private static final int SIDE = 32;

    private int featureSize = SIDE * SIDE;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";

    private int trainingSetSize;
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainintSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    /** Loads ./trainingDigits and ./testDigits1; a missing directory yields an empty set. */
    public Knn() {
        loadTrainingSet();
        loadTestSet();
    }

    /** Loads the training and test samples from the given directories. */
    public Knn(String trainingSetDir, String testSetDir) {
        this.trainingSetDir = trainingSetDir;
        this.testSetDir = testSetDir;
        loadTrainingSet();
        loadTestSet();
    }

    /** Reads every sample of trainingSetDir into trainingData / trainintSetLabel. */
    private void loadTrainingSet() {
        File[] files = listSampleFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        trainingSetSize = loadSamples(files, data, labels);
        // Trim away slots of skipped files so every stored row is a real sample.
        trainingData = Arrays.copyOf(data, trainingSetSize);
        trainintSetLabel = Arrays.copyOf(labels, trainingSetSize);
    }

    /** Reads every sample of testSetDir into testData / testSetLabel. */
    private void loadTestSet() {
        File[] files = listSampleFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        testSetSize = loadSamples(files, data, labels);
        testData = Arrays.copyOf(data, testSetSize);
        testSetLabel = Arrays.copyOf(labels, testSetSize);
    }

    /** Lists the regular files of dir sorted by path; empty if dir is missing or unreadable. */
    private static File[] listSampleFiles(String dir) {
        File[] entries = new File(dir).listFiles();
        if (entries == null) {
            System.err.println("sample directory not found: " + dir);
            return new File[0];
        }
        int count = 0;
        for (File entry : entries) {
            if (entry.isFile()) {
                entries[count++] = entry;
            }
        }
        File[] files = Arrays.copyOf(entries, count);
        // listFiles() order is unspecified; sorting makes neighbour tie-breaking reproducible.
        Arrays.sort(files);
        return files;
    }

    /**
     * Parses files into data/labels, skipping unreadable or mislabeled ones.
     *
     * @return number of samples stored at the front of data/labels
     */
    private static int loadSamples(File[] files, short[][] data, short[] labels) {
        int stored = 0;
        for (File file : files) {
            try {
                short label = Short.parseShort(file.getName().split("_")[0]);
                data[stored] = readSample(file);
                labels[stored] = label;
                stored++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping sample " + file + ": " + e);
            }
        }
        return stored;
    }

    /**
     * Reads the first 32 characters of the first 32 lines as a row-major digit vector.
     * Works with both "\n" and "\r\n" line endings.
     */
    private static short[] readSample(File file) throws IOException {
        short[] feature = new short[SIDE * SIDE];
        try (BufferedReader reader = Files.newBufferedReader(file.toPath(), StandardCharsets.UTF_8)) {
            for (int row = 0; row < SIDE; row++) {
                String line = reader.readLine();
                if (line == null || line.length() < SIDE) {
                    throw new IOException("expected " + SIDE + " digits on line " + (row + 1));
                }
                for (int col = 0; col < SIDE; col++) {
                    int digit = Character.digit(line.charAt(col), 10);
                    if (digit < 0) {
                        throw new IOException("non-digit character on line " + (row + 1));
                    }
                    feature[row * SIDE + col] = (short) digit;
                }
            }
        }
        return feature;
    }

    /**
     * Classifies a sample by majority vote among its k nearest training samples.
     *
     * @param feature feature vector of the sample to classify (32 * 32 entries)
     * @param k       number of neighbours; values larger than the training set use all samples
     * @return the winning label; on a tied vote, the label of the closest tied neighbour
     * @throws IllegalArgumentException if feature has the wrong length or k is below 1
     * @throws IllegalStateException    if no training samples were loaded
     */
    public int knn(short[] feature, int k) {
        if (feature == null || feature.length != featureSize) {
            throw new IllegalArgumentException("feature must have " + featureSize + " entries");
        }
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty: " + trainingSetDir);
        }
        int neighbours = Math.min(k, trainingSetSize);

        double[] distances = new double[this.trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] argDistance = this.arg_sort(distances);

        HashMap<Short, Integer> vote = new HashMap<>();
        for (int i = 0; i < neighbours; i++) {
            Short label = trainintSetLabel[argDistance[i]];
            Integer score = vote.get(label);
            vote.put(label, score == null ? 1 : score + 1);
        }
        // Walk the neighbours nearest-first with a strict '>' so that, among labels with equal
        // votes, the one whose sample is closest wins (instead of HashMap iteration order).
        int result = trainintSetLabel[argDistance[0]];
        int maxVote = 0;
        for (int i = 0; i < neighbours; i++) {
            short label = trainintSetLabel[argDistance[i]];
            int score = vote.get(label);
            if (score > maxVote) {
                result = label;
                maxVote = score;
            }
        }
        return result;
    }

    /**
     * Returns the fraction of test samples classified correctly with k = 3,
     * or 0.0 when no test samples were loaded.
     */
    public double knnPrecise() {
        if (testSetSize == 0) {
            return 0.0;
        }
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], 3)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Euclidean distance between two equally long vectors.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    public double calculateDistance(short[] sequcence1, short[] sequence2) {
        if (sequcence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "length mismatch: " + sequcence1.length + " vs " + sequence2.length);
        }
        int distance = 0;
        for (int i = 0; i < sequcence1.length; i++) {
            int diff = sequcence1[i] - sequence2[i];
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }

    /**
     * Returns the indices of sequence ordered by ascending value; equal values keep index order.
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is a stable O(n log n) merge sort (the previous selection sort
        // was O(n^2) per query, i.e. millions of comparisons per classification).
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(sequence[a], sequence[b]);
            }
        });
        int[] indexOfSequence = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            indexOfSequence[i] = order[i];
        }
        return indexOfSequence;
    }

    /** Number of training samples successfully loaded. */
    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    /** Number of test samples successfully loaded. */
    public int getTestSetSize() {
        return testSetSize;
    }
}

3.结果：

在测试集上的准确性很高，但是实际应用中却远没有那么高。

import java.awt.AWTException;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionListener;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;

import javax.imageio.ImageIO;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextArea;
import javax.swing.SwingUtilities;

/**
 * Main window: a drawing board plus buttons to clear it, recognize the drawn digit and measure
 * the classifier's accuracy on the test set.
 */
public class Recognition extends JFrame implements ActionListener {

    final private int windowWidth = 493;
    final private int windowHeight = 380;
    final private int windowX = 100;
    final private int windowY = 100;

    Board board = null;
    JButton reWriteButton = null;
    JButton recognitionButton = null;
    JButton testButton = null;
    JTextArea showResult = null;

    private int contentPaneX;
    private int contentPaneY;

    /** Classifier created on first use and reused, so the sample directories are read only once. */
    private Knn knn = null;

    /** Builds the window, wires every listener and shows it. */
    public Recognition() {
        board = new Board();
        this.setLayout(null);
        board.setBounds(8, 8, 332, 332);
        // The board draws strokes only from its own drag events.
        board.addMouseMotionListener(board);

        reWriteButton = new JButton("Rewrite");
        reWriteButton.setBounds(340, 10, 130, 30);

        recognitionButton = new JButton("Recognition");
        recognitionButton.setBounds(340, 40, 130, 30);

        testButton = new JButton("testPrecise");
        testButton.setBounds(340, 80, 130, 30);

        showResult = new JTextArea();
        showResult.setOpaque(true);
        showResult.setBackground(Color.CYAN);
        showResult.setForeground(Color.BLACK);
        showResult.setFont(new Font("微软雅黑", Font.BOLD, 12));
        showResult.setLineWrap(true);
        showResult.setBounds(340, 180, 130, 150);
        showResult.setVisible(false);

        for (JButton button : new JButton[] {reWriteButton, recognitionButton, testButton}) {
            button.addActionListener(this);
            this.add(button);
        }
        this.add(board);
        this.add(showResult);

        this.setTitle("HandWriting Recognition");
        this.setSize(windowWidth, windowHeight);
        this.setLocation(windowX, windowY);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        // Show the frame only after every component is in place.
        this.setVisible(true);
    }

    /** Starts the application on the Swing event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                new Recognition();
            }
        });
    }

    /** Dispatches button clicks; failures are shown in the result area instead of being lost. */
    @Override
    public void actionPerformed(ActionEvent e) {
        Object source = e.getSource();
        try {
            if (source == reWriteButton) {
                // Repainting the board wipes the strokes; the previous result no longer applies.
                showResult.setVisible(false);
                repaint();
            } else if (source == recognitionButton) {
                recognizeInput();
            } else if (source == testButton) {
                showPrecision();
            }
        } catch (RuntimeException ex) {
            ex.printStackTrace();
            showResult.setText("Error:\r\n" + ex);
            showResult.setVisible(true);
        }
    }

    /** Returns the shared classifier, loading the data sets on first use. */
    private Knn getKnn() {
        if (knn == null) {
            knn = new Knn();
        }
        return knn;
    }

    /** Screenshots the board, shrinks and binarizes it, classifies it and shows the digit. */
    private void recognizeInput() {
        this.contentPaneX = (int) this.getContentPane().getLocationOnScreen().getX();
        this.contentPaneY = (int) this.getContentPane().getLocationOnScreen().getY();
        // The board sits at (8, 8) in the content pane and its white area starts 1 px further in.
        ScreenShot screenShot = new ScreenShot(contentPaneX + 9, contentPaneY + 9, 320, 320, "maggie");
        screenShot.capture();
        // 320 px * 0.1 = 32 px, the 32x32 grid the classifier works on.
        ZoomImage zoomImage = new ZoomImage("." + File.separator + "maggie.png", 0.1f);
        zoomImage.zoom();
        RGB2binary rgb2binary = new RGB2binary("." + File.separator + "zoominMaggie.png");
        rgb2binary.rgb2binary();
        short[] userInput = rgb2binary.getUserInputDigit();

        int recognitionResult = getKnn().knn(userInput, 3);
        System.out.println("recognitionResult:" + recognitionResult);
        showResult.setText("Your input is \r\n" + String.valueOf(recognitionResult));
        showResult.setVisible(true);
    }

    /** Evaluates the classifier on the test set and shows set sizes and accuracy. */
    private void showPrecision() {
        Knn classifier = getKnn();
        double precise = classifier.knnPrecise();
        String string = "Training Set Size is :\r\n" + classifier.getTrainingSetSize()
                + "\r\nTest Set Size is :\r\n" + classifier.getTestSetSize()
                + "\r\nAccuracy is \r\n" + String.valueOf(precise);
        showResult.setText(string);
        showResult.setVisible(true);
    }
}

/**
 * A 320x320 white drawing area on which the user writes a digit by dragging the mouse.
 *
 * <p>Strokes are painted straight onto the screen and are not stored, so any repaint of the
 * board (for example after the frame's {@code repaint()}) wipes them.
 */
class Board extends JPanel implements MouseMotionListener {
    final private int boardWidth = 320;
    final private int boardHeight = 320;
    final private int boardX = 1;
    final private int boardY = 1;
    private int pencilWidth = 40;

    /** Paints a black 3D frame around a white drawing area of boardWidth x boardHeight pixels. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps a black, pencil-sized filled circle at the drag position while it lies inside the area.
     *
     * <p>Only takes effect once this board is registered as a MouseMotionListener; no registration
     * happens inside this class.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        // fillOval is anchored at its top-left corner, so keep one pencil width of margin
        // on the right/bottom to stay inside the white area.
        boolean inside = x > 1 && x < boardWidth - this.pencilWidth && y > 1 && y < boardHeight - pencilWidth;
        if (!inside) {
            return;
        }
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            // Not displayable yet: there is nothing on screen to draw on.
            return;
        }
        try {
            // Explicit ink colour: the later binarization treats dark pixels as "written".
            graphics.setColor(Color.BLACK);
            graphics.fillOval(x, y, pencilWidth, pencilWidth);
        } finally {
            graphics.dispose();
        }
    }

    /** Plain mouse movement (no button pressed) draws nothing. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }
}

/** Captures a rectangle of the screen and saves it as a PNG in the working directory. */
class ScreenShot {
    private int startX;
    private int startY;
    private int width;
    private int height;
    private String saveTo;

    /**
     * @param startX   screen x coordinate of the capture's top-left corner
     * @param startY   screen y coordinate of the capture's top-left corner
     * @param width    width of the captured rectangle in pixels
     * @param height   height of the captured rectangle in pixels
     * @param filename base name (without extension) of the PNG written to the working directory
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // File.separator rather than a hard-coded "\\": on non-Windows systems ".\\x.png" is a
        // single file literally named ".\x.png", not "x.png" in the current directory.
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Captures the configured screen rectangle and writes it as PNG to the save path.
     * Failures (e.g. no screen access, write error) print a stack trace and write no file.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            Rectangle area = new Rectangle(startX, startY, width, height);
            BufferedImage bufferedImage = new Robot().createScreenCapture(area);
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}

/** Scales an image file and writes the result to zoominMaggie.png in the working directory. */
class ZoomImage {
    private String filename;
    private float scaling;

    /**
     * @param filename path of the image to scale
     * @param scaling  factor applied to both dimensions; values below 1 shrink the image
     *                 (0.1f turns a 320x320 capture into 32x32)
     */
    public ZoomImage(String filename, float scaling) {
        this.filename = filename;
        this.scaling = scaling;
    }

    /**
     * Reads the image, scales it by the configured factor and writes it as PNG to
     * zoominMaggie.png in the working directory. Read/write problems are reported on stderr.
     */
    public void zoom() {
        File file = new File(this.filename);
        try {
            BufferedImage source = ImageIO.read(file);
            if (source == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + file);
                return;
            }
            // Compute the target size once; never let it collapse to 0, which BufferedImage rejects.
            int targetWidth = Math.max(1, (int) (this.scaling * source.getWidth()));
            int targetHeight = Math.max(1, (int) (this.scaling * source.getHeight()));
            BufferedImage scaled = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = scaled.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, targetWidth, targetHeight, null);
            } finally {
                graphics.dispose();
            }
            ImageIO.write(scaled, "png", new File("." + File.separator + "zoominMaggie.png"));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

/**
 * Converts an image into a 32x32 0/1 feature vector: dark pixels become 1, light pixels 0.
 * The vector is stored row-major (index = row * 32 + column).
 */
class RGB2binary {
    /** Side length of the feature grid. */
    private static final int SIDE = 32;

    private String filename;
    private short[] userInputDigit = new short[SIDE * SIDE];

    /** Returns the feature vector filled by the last call to rgb2binary(). */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    /** @param filename path of the image to binarize */
    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image, prints it as a 0/1 grid to stdout and stores it in the feature vector.
     * Only the top-left 32x32 pixels are used; cells not covered by a smaller image stay 0.
     * Read problems are reported on stderr and leave the vector zeroed.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        // Clear results of an earlier call in place (callers may hold the array reference).
        for (int i = 0; i < userInputDigit.length; i++) {
            userInputDigit[i] = 0;
        }
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("unsupported image format: " + file);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            // Rows follow the y axis and columns the x axis; clamp so a larger image cannot
            // overflow the fixed-size feature vector.
            int rows = Math.min(height, SIDE);
            int cols = Math.min(width, SIDE);
            for (int row = 0; row < rows; row++) {
                for (int col = 0; col < cols; col++) {
                    int pixel = bufferedImage.getRGB(startX + col, startY + row);
                    int r = (pixel & 0xff0000) >> 16;
                    int g = (pixel & 0xff00) >> 8;
                    int b = (pixel & 0xff);
                    // Luminance-weighted grey level, then a fixed threshold of 128.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit);
                    userInputDigit[row * SIDE + col] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

/**
 * k-nearest-neighbour classifier for 32x32 binary digit samples.
 *
 * <p>Each sample file must contain at least 32 lines of at least 32 digit characters; the label
 * is the part of the file name before the first '_' (e.g. "7_12.txt" has label 7). Files that do
 * not follow this format are skipped with a message on stderr.
 */
class Knn {
    /** Side length, in cells, of every sample grid. */
    private static final int SIDE = 32;

    private int featureSize = SIDE * SIDE;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";

    private int trainingSetSize;
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainintSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    /** Loads ./trainingDigits and ./testDigits1; a missing directory yields an empty set. */
    public Knn() {
        loadTrainingSet();
        loadTestSet();
    }

    /** Loads the training and test samples from the given directories. */
    public Knn(String trainingSetDir, String testSetDir) {
        this.trainingSetDir = trainingSetDir;
        this.testSetDir = testSetDir;
        loadTrainingSet();
        loadTestSet();
    }

    /** Reads every sample of trainingSetDir into trainingData / trainintSetLabel. */
    private void loadTrainingSet() {
        File[] files = listSampleFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        trainingSetSize = loadSamples(files, data, labels);
        // Trim away slots of skipped files so every stored row is a real sample.
        trainingData = Arrays.copyOf(data, trainingSetSize);
        trainintSetLabel = Arrays.copyOf(labels, trainingSetSize);
    }

    /** Reads every sample of testSetDir into testData / testSetLabel. */
    private void loadTestSet() {
        File[] files = listSampleFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        short[][] data = new short[files.length][];
        short[] labels = new short[files.length];
        testSetSize = loadSamples(files, data, labels);
        testData = Arrays.copyOf(data, testSetSize);
        testSetLabel = Arrays.copyOf(labels, testSetSize);
    }

    /** Lists the regular files of dir sorted by path; empty if dir is missing or unreadable. */
    private static File[] listSampleFiles(String dir) {
        File[] entries = new File(dir).listFiles();
        if (entries == null) {
            System.err.println("sample directory not found: " + dir);
            return new File[0];
        }
        int count = 0;
        for (File entry : entries) {
            if (entry.isFile()) {
                entries[count++] = entry;
            }
        }
        File[] files = Arrays.copyOf(entries, count);
        // listFiles() order is unspecified; sorting makes neighbour tie-breaking reproducible.
        Arrays.sort(files);
        return files;
    }

    /**
     * Parses files into data/labels, skipping unreadable or mislabeled ones.
     *
     * @return number of samples stored at the front of data/labels
     */
    private static int loadSamples(File[] files, short[][] data, short[] labels) {
        int stored = 0;
        for (File file : files) {
            try {
                short label = Short.parseShort(file.getName().split("_")[0]);
                data[stored] = readSample(file);
                labels[stored] = label;
                stored++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping sample " + file + ": " + e);
            }
        }
        return stored;
    }

    /**
     * Reads the first 32 characters of the first 32 lines as a row-major digit vector.
     * Works with both "\n" and "\r\n" line endings.
     */
    private static short[] readSample(File file) throws IOException {
        short[] feature = new short[SIDE * SIDE];
        try (BufferedReader reader = Files.newBufferedReader(file.toPath(), StandardCharsets.UTF_8)) {
            for (int row = 0; row < SIDE; row++) {
                String line = reader.readLine();
                if (line == null || line.length() < SIDE) {
                    throw new IOException("expected " + SIDE + " digits on line " + (row + 1));
                }
                for (int col = 0; col < SIDE; col++) {
                    int digit = Character.digit(line.charAt(col), 10);
                    if (digit < 0) {
                        throw new IOException("non-digit character on line " + (row + 1));
                    }
                    feature[row * SIDE + col] = (short) digit;
                }
            }
        }
        return feature;
    }

    /**
     * Classifies a sample by majority vote among its k nearest training samples.
     *
     * @param feature feature vector of the sample to classify (32 * 32 entries)
     * @param k       number of neighbours; values larger than the training set use all samples
     * @return the winning label; on a tied vote, the label of the closest tied neighbour
     * @throws IllegalArgumentException if feature has the wrong length or k is below 1
     * @throws IllegalStateException    if no training samples were loaded
     */
    public int knn(short[] feature, int k) {
        if (feature == null || feature.length != featureSize) {
            throw new IllegalArgumentException("feature must have " + featureSize + " entries");
        }
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty: " + trainingSetDir);
        }
        int neighbours = Math.min(k, trainingSetSize);

        double[] distances = new double[this.trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] argDistance = this.arg_sort(distances);

        HashMap<Short, Integer> vote = new HashMap<>();
        for (int i = 0; i < neighbours; i++) {
            Short label = trainintSetLabel[argDistance[i]];
            Integer score = vote.get(label);
            vote.put(label, score == null ? 1 : score + 1);
        }
        // Walk the neighbours nearest-first with a strict '>' so that, among labels with equal
        // votes, the one whose sample is closest wins (instead of HashMap iteration order).
        int result = trainintSetLabel[argDistance[0]];
        int maxVote = 0;
        for (int i = 0; i < neighbours; i++) {
            short label = trainintSetLabel[argDistance[i]];
            int score = vote.get(label);
            if (score > maxVote) {
                result = label;
                maxVote = score;
            }
        }
        return result;
    }

    /**
     * Returns the fraction of test samples classified correctly with k = 3,
     * or 0.0 when no test samples were loaded.
     */
    public double knnPrecise() {
        if (testSetSize == 0) {
            return 0.0;
        }
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], 3)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Euclidean distance between two equally long vectors.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    public double calculateDistance(short[] sequcence1, short[] sequence2) {
        if (sequcence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "length mismatch: " + sequcence1.length + " vs " + sequence2.length);
        }
        int distance = 0;
        for (int i = 0; i < sequcence1.length; i++) {
            int diff = sequcence1[i] - sequence2[i];
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }

    /**
     * Returns the indices of sequence ordered by ascending value; equal values keep index order.
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is a stable O(n log n) merge sort (the previous selection sort
        // was O(n^2) per query, i.e. millions of comparisons per classification).
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(sequence[a], sequence[b]);
            }
        });
        int[] indexOfSequence = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            indexOfSequence[i] = order[i];
        }
        return indexOfSequence;
    }

    /** Number of training samples successfully loaded. */
    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    /** Number of test samples successfully loaded. */
    public int getTestSetSize() {
        return testSetSize;
    }
}

0
0

个人资料
• 访问：219606次
• 积分：3189
• 等级：
• 排名：第12760名
• 原创：95篇
• 转载：0篇
• 译文：0篇
• 评论：65条
最新评论