关闭

机器学习--knn手写数字识别系统

标签: 机器学习knn手写数字识别java截取特定位置并保存java实现图片缩放
1143人阅读 评论(0) 收藏 举报
分类:

0.k近邻算法

刚接触java,并且在学习机器学习的相关算法,knn又非常的易于实现,于是就有了这个小系统。

1.knn算法简介:

存在一个样本数据集合,也称为训练样本集,并且样本集中的每一个数据都有标签,即我们知道样本集中的每一个数据的特征和对应的类型。当输入没有标签的新的数据的时候,将新的数据集的每一个特征和样本集中的每一个数据的对应的特征进行比较(计算两个样本的特征之间的距离),然后提取样本集中和输入的新数据特征最相似的数据的类的标签,通常我们只关心前k个最相似的数据,这就是k近算法中的k的出处。一般来说,我们只选择样本数据集中的前k最相似的数据,然后选择k个最相似的数据集中出现次数最多的作为新数据的分类。

2.该程序的功能主要有如下几个,

      功能1:可以在面板上手写输入数字

      功能2:可以对特定的区域进行截屏,因为要获取用户手写的数字,保存为图像,然后使用算法进行分析

      功能3:可以对图片进行缩放,要保证图片的大小(维度)要和数据集中的大小一样。

      功能4:可以将彩色图片转化为二值图片

      功能5:对图片中的手写数字使用KNN算法进行识别,也可以在测试集上计算算法的准确性。

(演示)



功能1的实现代码:手写板

创建一个JPane类的子类,通过监听mouseDragged事件,调用graphics来实现手写板的功能。

class Board extends JPanel implements MouseMotionListener {
	final private int boardWidth = 320;
	final private int boardHeight = 320;
	final private int boardX = 1;
	final private int boardY = 1;
	private int pencilWidth = 40;

	public void paint(Graphics graphics) {
		super.paint(graphics);
		graphics.setColor(Color.BLACK);
		graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
		graphics.setColor(Color.WHITE);
		graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
	}

	@Override
	public void mouseDragged(MouseEvent e) {
		// TODO Auto-generated method stub
		Graphics graphics = this.getGraphics();
		if (e.getX() > 1 && e.getX() < boardWidth - this.pencilWidth && e.getY() > 1
				&& e.getY() < boardHeight - pencilWidth)
			graphics.fillOval(e.getX(), e.getY(), pencilWidth, pencilWidth);
	}

	@Override
	public void mouseMoved(MouseEvent e) {
		// TODO Auto-generated method stub

	}
}

功能2的实现:可以对特定的区域进行截屏

class ScreenShot {
	private int startX;
	private int startY;
	private int width;
	private int height;
	private String saveTo;

	public ScreenShot(int startX, int startY, int width, int height, String filename) {
		this.startX = startX;//截取的起始x坐标
		this.startY = startY;//截取的起始y坐标
		this.width = width;  //截取的宽度
		this.height = height;//截取的高度
		this.saveTo = ".\\" + filename + ".png";//图片的保存位置
	}

	public void capture() {
		File file = new File(saveTo);
		try {
			BufferedImage bufferedImage = (new Robot())
					.createScreenCapture(new Rectangle(startX, startY, width, height));
			ImageIO.write(bufferedImage, "png", file);
			System.out.println("capture image has finish...");
		} catch (AWTException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

功能3:可以对图片进行缩放,要保证图片的大小(维度)要和数据集中的大小一样。

class ZoomImage {
	private String filename;
	private float scaling;

	public ZoomImage(String filename, float scaling) {
		this.filename = filename;//scaling为缩放比例,在这里是缩小的比例
		this.scaling = scaling;
	}

	public void zoom() {
		File file = new File(this.filename);
		try {
			BufferedImage bufferedImage1 = ImageIO.read(new File(filename));
			BufferedImage bufferedImage2 = new BufferedImage((int) (this.scaling * bufferedImage1.getWidth()),
					(int) (this.scaling * bufferedImage1.getHeight()), BufferedImage.TYPE_INT_BGR);
			Graphics graphics = bufferedImage2.createGraphics();
			graphics.drawImage(bufferedImage1, 0, 0, (int) (this.scaling * bufferedImage1.getWidth()),
					(int) (this.scaling * bufferedImage1.getHeight()), null);
			ImageIO.write(bufferedImage2, "png", new File(".\\zoominMaggie.png"));
			System.out.println("image has been zoomed...");
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

功能4:可以将彩色图片转化为二值图片

class RGB2binary {
	private String filename;
	private short[] userInputDigit = new short[32 * 32];

	public short[] getUserInputDigit() {
		return this.userInputDigit;
	}

	public RGB2binary(String filename) {
		this.filename = filename;
	}

	public void rgb2binary() {
		System.out.println(this.filename);
		File file = new File(this.filename);
		try {
			BufferedImage bufferedImage = ImageIO.read(file);
			int startX = bufferedImage.getMinX();
			int startY = bufferedImage.getMinY();
			int width = bufferedImage.getWidth();
			int height = bufferedImage.getHeight();
			System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
			for (int i = startX; i < width; i++) {
				for (int j = startY; j < height; j++) {
					int pixel = bufferedImage.getRGB(j, i);
					int r = (pixel & 0xff0000) >> 16;//得到该像素点的R值
					int g = (pixel & 0xff00) >> 8;
					int b = (pixel & 0xff);
					float gray = r * 0.3f + g * 0.59f + b * 0.11f;//灰度变为二值的计算公式
					if (gray > 128) {
						System.out.print(0 + "");
						userInputDigit[i * width + j] = 0;
					} else {
						System.out.print(1 + "");
						userInputDigit[i * width + j] = 1;
					}
				}
				System.out.println();
			}
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

功能5:对图片中的手写数字使用KNN算法进行识别,也可以在测试集上计算算法的准确性。

class Knn {
	private int featureSize = 32 * 32;
	String trainingSetDir = "./trainingDigits";
	String testSetDir = "./testDigits1";

	private int trainingSetSize;
	private int testSetSize;
	private short[][] trainingData = null;
	private short[] trainintSetLabel = null;
	private short[][] testData = null;
	private short[] testSetLabel = null;

	public Knn() {

	}

	//读取训练集
	public void readTrainingSet() {
		File path = new File(trainingSetDir);
		File files[] = path.listFiles();
		
		System.out.println("total file number: " + files.length);
		
		this.trainingSetSize = files.length;
		trainingData = new short[trainingSetSize][32 * 32];
		trainintSetLabel = new short[trainingSetSize];
		
		int fileCount = 0;
		for (File file : files) {
			
			String[] filename = file.getName().split("_");
			trainintSetLabel[fileCount] = Short.parseShort(String.valueOf(filename[0]));
			
			int lines = 0;
			char buff[] = new char[32 + 2];		//为什么要+2:因为要读取文件末尾的换行和回车
			int count = 0;
			try {
				FileReader fileReader = new FileReader(file);
				while( -1 != (count = fileReader.read(buff)) ){
					for( int i = 0; i < 32; i++ )
						trainingData[fileCount][lines * 32 + i] = Short.parseShort(String.valueOf(buff[i]));
					lines++;
				}
				fileReader.close();
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			fileCount++;
		}
	}
	
	//读取测试集
	public void readTestSet()
	{
		File path = new File(testSetDir);
		File[] files = path.listFiles();
		
		System.out.println("total number of test file" + files.length);
		this.testSetSize = files.length;
		testData = new short[this.testSetSize][32 * 32];
		testSetLabel = new short[this.testSetSize];
		
		int fileCount = 0;
		for( File file : files )
		{
			String[] filename = file.getName().split("_");
			testSetLabel[fileCount] = Short.parseShort(String.valueOf(filename[0]));
			
			try {
				FileReader fileReader = new FileReader(file);
				int count = 0;
				int lines = 0;
				char buff[] = new char[32 + 2];
				while( -1 != (count = fileReader.read(buff)) )
				{
					for( int i = 0; i < 32; i++ )
						testData[fileCount][lines * 32 + i] = Short.parseShort(String.valueOf(buff[i]));
					lines++;
				}
				fileReader.close();
				fileCount++;
			} catch (FileNotFoundException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		}
	}
	
	//@feature,待判断的实例的特征向量,
	//@k,即为knn算法中的k
	//返回分类的结果
	public int knn(short[] feature, int k)
	{
		double[] distances = new double[this.trainingSetSize];

		for( int i = 0; i < trainingSetSize; i++ )
			distances[i] = calculateDistance(feature, trainingData[i]);
	
		int[] argDistance = this.arg_sort(distances);

		HashMap<Short, Integer> vote = new HashMap<>();
		for( int i = 0; i < k; i++ )
		{
			if ( null == vote.get(trainintSetLabel[argDistance[i]]) )
				vote.put(trainintSetLabel[argDistance[i]], 1);
			else
			{
				int score = vote.get(trainintSetLabel[argDistance[i]]) + 1;
				vote.put(trainintSetLabel[argDistance[i]], score);
			}
		}
		int result = 0;
		int maxVote = 0;
		for( short key : vote.keySet() )
		{
			if( maxVote < vote.get(key) )
			{
				result = key;
				maxVote = vote.get(key);
			}
		}
		return result;
	}
	
	//在测试集上计算该算法的准确性
	public double knnPrecise()
	{
		System.out.println("reading trainingSet...");
		this.readTrainingSet();
		System.out.println("reading trainingSet over");
		System.out.println("reading testSet...");
		this.readTestSet();
		System.out.println("reading testSet end");
		
		int success = 0;
		for( int i = 0; i < testSetSize; i++ )
			if( testSetLabel[i] == knn(testData[i], 3) )
				success++;
		return (double)success/testSetSize;
	}
	
	public double calculateDistance(short[] sequcence1, short[] sequence2)
	{
		int distance = 0;
		for( int i = 0; i < sequcence1.length; i++ )
			distance += (sequcence1[i] - sequence2[i]) * (sequcence1[i] - sequence2[i]);
		return Math.sqrt(distance);
	}
	
	//返回的是sequence升序的下标序列
	public int[] arg_sort(double[] sequence)
	{
		double[] sequence1 = sequence.clone();
		
		int[] indexOfSequence = new int[sequence.length];
		for( int i = 0; i < sequence1.length; i++ )
			indexOfSequence[i] = i;
		
		double minValue, tempD;
		int minIndex,tempI;
		for( int i = 0; i < sequence1.length - 1; i++ )
		{
			minValue = sequence1[i];
			minIndex = i;
			for( int j = i + 1; j < sequence1.length; j++ )
			{
				if( sequence1[j] < minValue )
				{
					minValue = sequence1[j];
					minIndex = j;
				}
			}
			if( i != minIndex )
			{
				tempD = sequence1[minIndex];
				tempI = indexOfSequence[minIndex];
				sequence1[minIndex] = sequence1[i];
				indexOfSequence[minIndex] = indexOfSequence[i];
				sequence1[i] = tempD;
				indexOfSequence[i] = tempI;
			}
		}
		return indexOfSequence;
	}
	
	
	public int getTrainingSetSize() {
		return trainingSetSize;
	}
	public int getTestSetSize() {
		return testSetSize;
	}
}

3.结果:

 在测试集上的准确性很高,但是实际应用中却远没有那么高。


完整代码

import java.awt.AWTException;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionListener;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;

import javax.imageio.ImageIO;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextArea;

public class Recognition extends JFrame implements ActionListener {

	final private int windowWidth = 493;
	final private int windowHeight = 380;
	final private int windowX = 100;
	final private int windowY = 100;

	Board board = null;
	JButton reWriteButton = null;
	JButton recognitionButton = null;
	JButton testButton = null;
	JTextArea showResult = null;

	private int contentPaneX;
	private int contentPaneY;

	public Recognition() {
		board = new Board();
		this.setLayout(null);
		this.add(board);
		board.setBounds(8, 8, 332, 332);
		board.addMouseMotionListener(board);

		reWriteButton = new JButton("Rewrite");
		this.add(reWriteButton);
		reWriteButton.setBounds(340, 10, 130, 30);
		reWriteButton.addActionListener(this);

		recognitionButton = new JButton("Recognition");
		this.add(recognitionButton);
		recognitionButton.setBounds(340, 40, 130, 30);
		recognitionButton.addActionListener(this);
		
		testButton = new JButton("testPrecise");
		this.add(testButton);
		testButton.setBounds(340, 80, 130, 30);
		testButton.addActionListener(this);
		
		showResult = new JTextArea();
		showResult.setOpaque(true);
		showResult.setBackground(Color.CYAN);
		showResult.setForeground(Color.BLACK);
		showResult.setFont(new Font("微软雅黑", Font.BOLD, 12));
		showResult.setLineWrap(true);
		this.add(showResult);
		showResult.setBounds(340, 180, 130, 150);
		showResult.setVisible(false);
		this.setTitle("HandWriting Recognition");

		this.setSize(windowWidth, windowHeight);
		this.setLocation(windowX, windowY);
		this.setVisible(true);
		this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
	}

	public static void main(String[] args) {
		// TODO Auto-generated method stub
		Recognition recognition = new Recognition();
	}

	@Override
	public void actionPerformed(ActionEvent e) {
		// TODO Auto-generated method stub
		if (e.getSource() == reWriteButton) {
			repaint();
		} else if (e.getSource() == recognitionButton) {
			this.contentPaneX = (int) this.getContentPane().getLocationOnScreen().getX();
			this.contentPaneY = (int) this.getContentPane().getLocationOnScreen().getY();
			ScreenShot screenShot = new ScreenShot(contentPaneX + 9, contentPaneY + 9, 320, 320, "maggie");
			screenShot.capture();
			ZoomImage zoomImage = new ZoomImage("./maggie.png", 0.1f);
			zoomImage.zoom();
			RGB2binary rgb2binary = new RGB2binary("./zoominMaggie.png");
			rgb2binary.rgb2binary();
			short[] userInput = rgb2binary.getUserInputDigit();
			
			Knn knn = new Knn();
			System.out.println("reading trainingSet...");
			knn.readTrainingSet();
			System.out.println("reading trainingSet over");
			int recognitionResult = knn.knn(userInput, 3);
			System.out.println("recognitionResult:"+ recognitionResult);
			showResult.setText("Your input is \r\n" + String.valueOf(recognitionResult));
			showResult.setVisible(true);
		} else if ( e.getSource() == testButton ){
			Knn knn = new Knn();
			double precise = knn.knnPrecise();
			
			String string = "Training Set Size is :\r\n" + knn.getTrainingSetSize() + "\r\nTest Set Size is :\r\n" + knn.getTestSetSize() + "\r\nAccury is \r\n" + String.valueOf(precise);
			showResult.setText(string);
			showResult.setVisible(true);
		}
	}
}

class Board extends JPanel implements MouseMotionListener {
	final private int boardWidth = 320;
	final private int boardHeight = 320;
	final private int boardX = 1;
	final private int boardY = 1;
	private int pencilWidth = 40;

	public void paint(Graphics graphics) {
		super.paint(graphics);
		graphics.setColor(Color.BLACK);
		graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
		graphics.setColor(Color.WHITE);
		graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
	}

	@Override
	public void mouseDragged(MouseEvent e) {
		// TODO Auto-generated method stub
		Graphics graphics = this.getGraphics();
		if (e.getX() > 1 && e.getX() < boardWidth - this.pencilWidth && e.getY() > 1
				&& e.getY() < boardHeight - pencilWidth)
			graphics.fillOval(e.getX(), e.getY(), pencilWidth, pencilWidth);
	}

	@Override
	public void mouseMoved(MouseEvent e) {
		// TODO Auto-generated method stub

	}
}

class ScreenShot {
	private int startX;
	private int startY;
	private int width;
	private int height;
	private String saveTo;

	public ScreenShot(int startX, int startY, int width, int height, String filename) {
		this.startX = startX;
		this.startY = startY;
		this.width = width;
		this.height = height;
		this.saveTo = ".\\" + filename + ".png";
	}

	public void capture() {
		File file = new File(saveTo);
		try {
			BufferedImage bufferedImage = (new Robot())
					.createScreenCapture(new Rectangle(startX, startY, width, height));
			ImageIO.write(bufferedImage, "png", file);
			System.out.println("capture image has finish...");
		} catch (AWTException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

class ZoomImage {
	private String filename;
	private float scaling;

	public ZoomImage(String filename, float scaling) {
		this.filename = filename;
		this.scaling = scaling;
	}

	public void zoom() {
		File file = new File(this.filename);
		try {
			BufferedImage bufferedImage1 = ImageIO.read(new File(filename));
			BufferedImage bufferedImage2 = new BufferedImage((int) (this.scaling * bufferedImage1.getWidth()),
					(int) (this.scaling * bufferedImage1.getHeight()), BufferedImage.TYPE_INT_BGR);
			Graphics graphics = bufferedImage2.createGraphics();
			graphics.drawImage(bufferedImage1, 0, 0, (int) (this.scaling * bufferedImage1.getWidth()),
					(int) (this.scaling * bufferedImage1.getHeight()), null);
			ImageIO.write(bufferedImage2, "png", new File(".\\zoominMaggie.png"));
			System.out.println("image has been zoomed...");
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

class RGB2binary {
	private String filename;
	private short[] userInputDigit = new short[32 * 32];

	public short[] getUserInputDigit() {
		return this.userInputDigit;
	}

	public RGB2binary(String filename) {
		this.filename = filename;
	}

	public void rgb2binary() {
		System.out.println(this.filename);
		File file = new File(this.filename);
		try {
			BufferedImage bufferedImage = ImageIO.read(file);
			int startX = bufferedImage.getMinX();
			int startY = bufferedImage.getMinY();
			int width = bufferedImage.getWidth();
			int height = bufferedImage.getHeight();
			System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
			for (int i = startX; i < width; i++) {
				for (int j = startY; j < height; j++) {
					int pixel = bufferedImage.getRGB(j, i);
					int r = (pixel & 0xff0000) >> 16;
					int g = (pixel & 0xff00) >> 8;
					int b = (pixel & 0xff);
					float gray = r * 0.3f + g * 0.59f + b * 0.11f;
					if (gray > 128) {
						System.out.print(0 + "");
						userInputDigit[i * width + j] = 0;
					} else {
						System.out.print(1 + "");
						userInputDigit[i * width + j] = 1;
					}
				}
				System.out.println();
			}
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
}

class Knn {
	private int featureSize = 32 * 32;
	String trainingSetDir = "./trainingDigits";
	String testSetDir = "./testDigits1";

	private int trainingSetSize;
	private int testSetSize;
	private short[][] trainingData = null;
	private short[] trainintSetLabel = null;
	private short[][] testData = null;
	private short[] testSetLabel = null;

	public Knn() {

	}

	//读取训练集
	public void readTrainingSet() {
		File path = new File(trainingSetDir);
		File files[] = path.listFiles();
		
		System.out.println("total file number: " + files.length);
		
		this.trainingSetSize = files.length;
		trainingData = new short[trainingSetSize][32 * 32];
		trainintSetLabel = new short[trainingSetSize];
		
		int fileCount = 0;
		for (File file : files) {
			
			String[] filename = file.getName().split("_");
			trainintSetLabel[fileCount] = Short.parseShort(String.valueOf(filename[0]));
			
			int lines = 0;
			char buff[] = new char[32 + 2];		//为什么要+2:因为要读取文件末尾的换行和回车
			int count = 0;
			try {
				FileReader fileReader = new FileReader(file);
				while( -1 != (count = fileReader.read(buff)) ){
					for( int i = 0; i < 32; i++ )
						trainingData[fileCount][lines * 32 + i] = Short.parseShort(String.valueOf(buff[i]));
					lines++;
				}
				fileReader.close();
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			fileCount++;
		}
	}
	
	//读取测试集
	public void readTestSet()
	{
		File path = new File(testSetDir);
		File[] files = path.listFiles();
		
		System.out.println("total number of test file" + files.length);
		this.testSetSize = files.length;
		testData = new short[this.testSetSize][32 * 32];
		testSetLabel = new short[this.testSetSize];
		
		int fileCount = 0;
		for( File file : files )
		{
			String[] filename = file.getName().split("_");
			testSetLabel[fileCount] = Short.parseShort(String.valueOf(filename[0]));
			
			try {
				FileReader fileReader = new FileReader(file);
				int count = 0;
				int lines = 0;
				char buff[] = new char[32 + 2];
				while( -1 != (count = fileReader.read(buff)) )
				{
					for( int i = 0; i < 32; i++ )
						testData[fileCount][lines * 32 + i] = Short.parseShort(String.valueOf(buff[i]));
					lines++;
				}
				fileReader.close();
				fileCount++;
			} catch (FileNotFoundException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		}
	}
	
	//@feature,待判断的实例的特征向量,
	//@k,即为knn算法中的k
	//返回分类的结果
	public int knn(short[] feature, int k)
	{
		double[] distances = new double[this.trainingSetSize];

		for( int i = 0; i < trainingSetSize; i++ )
			distances[i] = calculateDistance(feature, trainingData[i]);
	
		int[] argDistance = this.arg_sort(distances);

		HashMap<Short, Integer> vote = new HashMap<>();
		for( int i = 0; i < k; i++ )
		{
			if ( null == vote.get(trainintSetLabel[argDistance[i]]) )
				vote.put(trainintSetLabel[argDistance[i]], 1);
			else
			{
				int score = vote.get(trainintSetLabel[argDistance[i]]) + 1;
				vote.put(trainintSetLabel[argDistance[i]], score);
			}
		}
		int result = 0;
		int maxVote = 0;
		for( short key : vote.keySet() )
		{
			if( maxVote < vote.get(key) )
			{
				result = key;
				maxVote = vote.get(key);
			}
		}
		return result;
	}
	
	//在测试集上计算该算法的准确性
	public double knnPrecise()
	{
		System.out.println("reading trainingSet...");
		this.readTrainingSet();
		System.out.println("reading trainingSet over");
		System.out.println("reading testSet...");
		this.readTestSet();
		System.out.println("reading testSet end");
		
		int success = 0;
		for( int i = 0; i < testSetSize; i++ )
			if( testSetLabel[i] == knn(testData[i], 3) )
				success++;
		return (double)success/testSetSize;
	}
	
	public double calculateDistance(short[] sequcence1, short[] sequence2)
	{
		int distance = 0;
		for( int i = 0; i < sequcence1.length; i++ )
			distance += (sequcence1[i] - sequence2[i]) * (sequcence1[i] - sequence2[i]);
		return Math.sqrt(distance);
	}
	
	//返回的是sequence升序的下标序列
	public int[] arg_sort(double[] sequence)
	{
		double[] sequence1 = sequence.clone();
		
		int[] indexOfSequence = new int[sequence.length];
		for( int i = 0; i < sequence1.length; i++ )
			indexOfSequence[i] = i;
		
		double minValue, tempD;
		int minIndex,tempI;
		for( int i = 0; i < sequence1.length - 1; i++ )
		{
			minValue = sequence1[i];
			minIndex = i;
			for( int j = i + 1; j < sequence1.length; j++ )
			{
				if( sequence1[j] < minValue )
				{
					minValue = sequence1[j];
					minIndex = j;
				}
			}
			if( i != minIndex )
			{
				tempD = sequence1[minIndex];
				tempI = indexOfSequence[minIndex];
				sequence1[minIndex] = sequence1[i];
				indexOfSequence[minIndex] = indexOfSequence[i];
				sequence1[i] = tempD;
				indexOfSequence[i] = tempI;
			}
		}
		return indexOfSequence;
	}
	
	
	public int getTrainingSetSize() {
		return trainingSetSize;
	}
	public int getTestSetSize() {
		return testSetSize;
	}
}



0
0
查看评论

数字识别系统

  • 2008-04-10 08:05
  • 200KB
  • 下载

[毕业设计]手写数字识别系统设计与实现

内容摘要:本文论述并设计实现了一个自由手写体数字识别系统。文中首先对待识别数字的预处理进行了介绍,包括二值化、噪声处理、图像分割、归一化、细化等图像处理方法;其次,探讨了数字字符特征向量的提取;最后采用了bp神经网络算法,并以MATLAB作为编程工具实现了具有友好的图形用户界面的自由手写体数字识别系...
  • szydwy
  • szydwy
  • 2016-03-20 20:30
  • 4955

图片数字自动识别工具

  • 2012-11-19 11:36
  • 1.13MB
  • 下载

数字识别系统源代码

  • 2006-08-14 09:27
  • 204KB
  • 下载

使用Knn算法实现手写数字识别系统(附带jpg转txt代码)

手写识别暂且可以理解为:一个jpg格式,然后转换为txt格式,这里普遍用01来替代。 先附上代码,以后可以自己构建此类数据集 jpg格式转txt格式:(代码的核心思想是如何求出对应灰度值后填充自己规定的ascii_char) from PIL import Image import a...
  • qq_33638791
  • qq_33638791
  • 2017-01-16 23:57
  • 523

机器学习实战kNN之手写识别

kNN算法算是机器学习入门级绝佳的素材。书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征比较,算法提取样本集中特征最相似数据(最近邻)的分类标...
  • wyb_009
  • wyb_009
  • 2013-06-25 21:44
  • 6200

机器学习实战k近邻算法(kNN)应用之手写数字识别代码解读

一.背景简要说明 书中假设待识别的数字已经使用图形处理软件将其处理为32*32的黑白图像,并将图片转换为文本格式。如下图代表数字0: 每个数字的训练样本大概有200个,每个数字的测试样本大概有100个,分别放在trainingDigits和testDigits中。 ...
  • SCUT_Arucee
  • SCUT_Arucee
  • 2015-12-11 11:33
  • 3921

机器学习实战笔记——基于KNN算法的手写识别系统

利用k-近邻分类器实现手写识别系统,训练数据集大约2000个样本,每个数字大约有200个样本,每个样本保存在一个txt文件中,手写体图像本身是32X32的二值图像,如下图所示: 首先,我们需要将图像格式化处理为一个向量,把一个32X32的二进制图像矩阵通过img2vector()函数转换为1...
  • geekmanong
  • geekmanong
  • 2016-01-17 14:16
  • 2460

机器学习整理笔记——使用k-近邻算法对手写识别系统的测试

使用k-近邻算法对约会网站进行配对
  • pandafactory
  • pandafactory
  • 2016-06-06 17:07
  • 952

KNN--用于手写数字识别(机器学习入门笔记)

最近在看机器学习实战这本书,写下博客作为笔记以帮助记忆。一、K-近邻算法概述概括的说,K-近邻算法采用测量不同特征值之间的距离的方法进行分类。 它的工作原理是:存在一个样本数据集合,也称训练样本集,并且样本集中每个数据存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。输入没有标签的新数据...
  • zengxyuyu
  • zengxyuyu
  • 2017-01-13 21:32
  • 2601
    个人资料
    • 访问:219606次
    • 积分:3189
    • 等级:
    • 排名:第12760名
    • 原创:95篇
    • 转载:0篇
    • 译文:0篇
    • 评论:65条
    最新评论