上完机器学习课程,需要交一个大作业。想着学校场馆预订系统经常订不到,所以做一个自动提交表单的工具,关键在于验证码识别。
图像处理代码
训练集:
验证集:
package mine.imageProcesser; import jdk.nashorn.internal.objects.Global; import mine.common.GlobalPara; import java.awt.Color; import java.awt.image.BufferedImage; import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.imageio.ImageIO; public class ImagePreProcess { public static int isBlack(int colorInt) { Color color = new Color(colorInt); if (color.getRed() + color.getGreen() + color.getBlue() <= 100) { return 1; } return 0; } public static int isWhite(int colorInt) { Color color = new Color(colorInt); if (color.getRed() + color.getGreen() + color.getBlue() > 100) { return 1; } return 0; } public static BufferedImage removeBackgroud(String picFile) throws Exception { BufferedImage img = ImageIO.read(new File(picFile)); int width = img.getWidth(); int height = img.getHeight(); for (int x = 0; x < width; ++x) { for (int y = 0; y < height; ++y) { if (isWhite(img.getRGB(x, y)) == 1) { img.setRGB(x, y, Color.WHITE.getRGB()); } else { img.setRGB(x, y, Color.BLACK.getRGB()); } } } return img; } public static List<BufferedImage> splitImage(BufferedImage img) throws Exception { List<BufferedImage> subImgs = new ArrayList<BufferedImage>(); subImgs.add(img.getSubimage(10, 6, 8, 10)); subImgs.add(img.getSubimage(19, 6, 8, 10)); subImgs.add(img.getSubimage(28, 6, 8, 10)); subImgs.add(img.getSubimage(37, 6, 8, 10)); return subImgs; } public static Map<BufferedImage, String> loadTrainData() throws Exception { Map<BufferedImage, String> map = new HashMap<BufferedImage, String>(); File dir = new File(GlobalPara.trainPath); File[] files = dir.listFiles(); for (File file : files) { map.put(ImageIO.read(file), file.getName().charAt(0) + ""); } return map; } public static String getSingleCharOcr(BufferedImage img, Map<BufferedImage, String> map) { String result = ""; int width = img.getWidth(); int height = 
img.getHeight(); int min = width * height; for (BufferedImage bi : map.keySet()) { int count = 0; Label1: for (int x = 0; x < width; ++x) { for (int y = 0; y < height; ++y) { if (isWhite(img.getRGB(x, y)) != isWhite(bi.getRGB(x, y))) { count++; if (count >= min) break Label1; } } } if (count < min) { min = count; result = map.get(bi); } } return result; } public static String getAllOcr(String file) throws Exception { BufferedImage img = removeBackgroud(file); List<BufferedImage> listImg = splitImage(img); Map<BufferedImage, String> map = loadTrainData(); String result = ""; for (BufferedImage bi : listImg) { result += getSingleCharOcr(bi, map); } ImageIO.write(img, "JPG", new File(GlobalPara.resultPath+result+".jpg")); return result; } /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { for (int i = 0; i < 30; ++i) { String text = getAllOcr(GlobalPara.imgPath + i + ".jpg"); System.out.println(i + ".jpg = " + text); } } } ------------------------------------------------------------------------- BP神经网络package mine; /** * Created by robby on 2017/6/5. 
/* Back-propagation (BP) neural network. */
import java.util.Random;

/**
 * A fully connected network with one input, one hidden and one output layer, trained by
 * back-propagation with a momentum term. All units use the logistic sigmoid activation.
 *
 * <p>Every layer array has one extra slot at index 0 that acts as the bias unit: it is forced
 * to 1.0 before each forward pass and each weight update, and real units live at indices
 * {@code 1..size}. Weights are initialised in (-1, 1) from a fixed seed, so two networks built
 * with the same sizes start out identical.
 */
public class BP {
    // Input layer activations (index 0 is the bias unit).
    private final double[] input;
    // Hidden layer activations (index 0 is the bias unit).
    private final double[] hidden;
    // Output layer activations (index 0 is unused).
    private final double[] output;
    // Target values for the current training sample (index 0 is unused).
    private final double[] target;

    // Error terms (deltas) of the hidden layer.
    private final double[] hidDelta;
    // Error terms (deltas) of the output layer.
    private final double[] optDelta;

    // Learning rate.
    private final double eta;
    // Momentum factor applied to the previous weight update.
    private final double momentum;

    // Weights between the input layer and the hidden layer.
    private final double[][] iptHidWeights;
    // Weights between the hidden layer and the output layer.
    private final double[][] hidOptWeights;

    // Previous weight updates between the input layer and the hidden layer.
    private final double[][] iptHidPrevUptWeights;
    // Previous weight updates between the hidden layer and the output layer.
    private final double[][] hidOptPrevUptWeights;

    /** Sum of absolute output-layer deltas from the most recent training step. */
    public double optErrSum = 0d;

    /** Sum of absolute hidden-layer deltas from the most recent training step. */
    public double hidErrSum = 0d;

    private final Random random;

    /**
     * Creates a network with explicit learning parameters.
     *
     * @param inputSize number of input units
     * @param hiddenSize number of hidden units
     * @param outputSize number of output units
     * @param eta learning rate
     * @param momentum momentum factor
     * @throws IllegalArgumentException if any layer size is negative
     */
    public BP(int inputSize, int hiddenSize, int outputSize, double eta, double momentum) {
        if (inputSize < 0 || hiddenSize < 0 || outputSize < 0) {
            throw new IllegalArgumentException("Layer sizes must not be negative.");
        }

        input = new double[inputSize + 1];
        hidden = new double[hiddenSize + 1];
        output = new double[outputSize + 1];
        target = new double[outputSize + 1];

        hidDelta = new double[hiddenSize + 1];
        optDelta = new double[outputSize + 1];

        iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptWeights = new double[hiddenSize + 1][outputSize + 1];

        // Fixed seed: initial weights are reproducible from run to run.
        random = new Random(19881211);
        randomizeWeights(iptHidWeights);
        randomizeWeights(hidOptWeights);

        iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];

        this.eta = eta;
        this.momentum = momentum;
    }

    /** Fills the matrix with random values in (-1, 1); magnitude and sign are drawn separately. */
    private void randomizeWeights(double[][] matrix) {
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                double real = random.nextDouble();
                matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
            }
        }
    }

    /**
     * Creates a network with the default learning rate (0.5) and momentum (0.9).
     *
     * @param inputSize number of input units
     * @param hiddenSize number of hidden units
     * @param outputSize number of output units
     */
    public BP(int inputSize, int hiddenSize, int outputSize) {
        this(inputSize, hiddenSize, outputSize, 0.5, 0.9);
    }

    /**
     * Performs one training step on a single sample: forward pass, delta computation and
     * weight update.
     *
     * @param trainData input values; length must equal the input size
     * @param target desired output values; length must equal the output size
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public void train(double[] trainData, double[] target) {
        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateDelta();
        adjustWeight();
    }

    /**
     * Runs a forward pass without changing any weight.
     *
     * @param inData input values; length must equal the input size
     * @return a new array holding the activation of each output unit
     * @throws IllegalArgumentException if {@code inData} has the wrong length
     */
    public double[] test(double[] inData) {
        loadInput(inData);
        forward();
        return getNetworkOutput();
    }

    /**
     * Copies the output activations into a fresh array, dropping the unused slot 0.
     *
     * @return the output activations
     */
    private double[] getNetworkOutput() {
        double[] temp = new double[output.length - 1];
        System.arraycopy(output, 1, temp, 0, temp.length);
        return temp;
    }

    /**
     * Stores the desired outputs for the current sample.
     *
     * @param arg desired output values
     */
    private void loadTarget(double[] arg) {
        copyIntoLayer(arg, target);
    }

    /**
     * Stores the input values for the current sample.
     *
     * @param inData input values
     */
    private void loadInput(double[] inData) {
        copyIntoLayer(inData, input);
    }

    /** Copies {@code values} into slots {@code 1..n} of {@code layer}, checking the length. */
    private static void copyIntoLayer(double[] values, double[] layer) {
        if (values.length != layer.length - 1) {
            throw new IllegalArgumentException("Size Do Not Match.");
        }
        System.arraycopy(values, 0, layer, 1, values.length);
    }

    /**
     * Propagates activations from one layer to the next.
     *
     * @param layer0 source layer; its slot 0 is set to 1.0 so that it acts as the bias
     * @param layer1 destination layer; slots {@code 1..n} receive the sigmoid of the weighted sum
     * @param weight weights indexed as {@code [source][destination]}
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        layer0[0] = 1.0;
        for (int j = 1; j < layer1.length; ++j) {
            double sum = 0;
            for (int i = 0; i < layer0.length; ++i) {
                sum += weight[i][j] * layer0[i];
            }
            layer1[j] = sigmoid(sum);
        }
    }

    /** Runs the full forward pass: input to hidden, then hidden to output. */
    private void forward() {
        forward(input, hidden, iptHidWeights);
        forward(hidden, output, hidOptWeights);
    }

    /**
     * Computes the output-layer deltas, {@code o * (1 - o) * (target - o)}, and records the sum
     * of their absolute values in {@link #optErrSum}.
     */
    private void outputErr() {
        double errSum = 0;
        for (int idx = 1; idx < optDelta.length; ++idx) {
            double o = output[idx];
            optDelta[idx] = o * (1d - o) * (target[idx] - o);
            errSum += Math.abs(optDelta[idx]);
        }
        optErrSum = errSum;
    }

    /**
     * Computes the hidden-layer deltas by propagating the output deltas back through the
     * hidden-to-output weights, and records the sum of their absolute values in
     * {@link #hidErrSum}.
     */
    private void hiddenErr() {
        double errSum = 0;
        for (int j = 1; j < hidDelta.length; ++j) {
            double o = hidden[j];
            double sum = 0;
            for (int k = 1; k < optDelta.length; ++k) {
                sum += hidOptWeights[j][k] * optDelta[k];
            }
            hidDelta[j] = o * (1d - o) * sum;
            errSum += Math.abs(hidDelta[j]);
        }
        hidErrSum = errSum;
    }

    /** Computes the deltas of all layers; output first because the hidden deltas depend on it. */
    private void calculateDelta() {
        outputErr();
        hiddenErr();
    }

    /**
     * Updates the weights feeding one layer.
     *
     * @param delta deltas of the destination layer
     * @param layer activations of the source layer; slot 0 is set to 1 (bias)
     * @param weight weights to update, indexed as {@code [source][destination]}
     * @param prevWeight previous updates for the same weights, used for the momentum term and
     *     overwritten with the new updates
     */
    private void adjustWeight(double[] delta, double[] layer, double[][] weight,
            double[][] prevWeight) {
        layer[0] = 1;
        for (int i = 1; i < delta.length; ++i) {
            for (int j = 0; j < layer.length; ++j) {
                double newVal = momentum * prevWeight[j][i] + eta * delta[i] * layer[j];
                weight[j][i] += newVal;
                prevWeight[j][i] = newVal;
            }
        }
    }

    /** Updates the weights of both layers using the deltas from {@link #calculateDelta()}. */
    private void adjustWeight() {
        adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
        adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
    }

    /**
     * Logistic sigmoid activation function.
     *
     * @param val weighted input sum
     * @return {@code 1 / (1 + e^-val)}
     */
    private double sigmoid(double val) {
        return 1d / (1d + Math.exp(-val));
    }
}
// -----------------------------------------------------------------------------
// Global variables
package mine.common;

/**
 * Created by robby on 2017/6/5.
 *
 * <p>Directory locations shared by the recognizer programs. Each path ends with a file
 * separator so that a file name can be appended directly. The class only holds constants and
 * cannot be instantiated or extended.
 */
public final class GlobalPara {
    /** Directory containing the training images. */
    public static final String trainPath = "D:\\Projects\\ml\\recognizer\\data\\train\\";

    /** Directory containing the captcha images to recognize. */
    public static final String imgPath = "D:\\Projects\\ml\\recognizer\\data\\img\\";

    /** Directory that receives the processed result images. */
    public static final String resultPath = "D:\\Projects\\ml\\recognizer\\data\\result\\";

    private GlobalPara() {
        // Constants holder; not meant to be instantiated.
    }
}
// ----------------------------------------------------------------------------------
// Main class
package mine;

/**
 * Created by robby on 2017/6/5.
 */

import mine.common.GlobalPara;
import mine.imageProcesser.ImagePreProcess;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;

/**
 * Trains a {@link BP} network on the captcha images in {@link GlobalPara#trainPath} and
 * reports its per-character accuracy on the images in {@link GlobalPara#imgPath}.
 *
 * <p>File names are expected to start with the digits shown in the image; the character at
 * position {@code m} of the name is the label of the {@code m}-th glyph.
 */
public class Main {

    /** Number of passes over the training set. */
    private static final int TRAINING_EPOCHS = 500;

    /** Number of digit classes (0-9), i.e. the length of a one-hot label vector. */
    private static final int DIGIT_CLASSES = 10;

    /**
     * Flattens an image into a feature vector with one entry per pixel, in column-major order
     * (index {@code x * height + y}). An entry is 1.0 for a black pixel and 0.0 otherwise, as
     * decided by {@link ImagePreProcess#isBlack(int)}, which keeps the inputs in a range the
     * sigmoid units can work with.
     *
     * @param image the image to convert
     * @return an array of length {@code width * height}
     */
    public static double[] image2double(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        double[] result = new double[width * height];
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                result[x * height + y] = ImagePreProcess.isBlack(image.getRGB(x, y));
            }
        }
        return result;
    }

    /**
     * Loads every image in a directory as a feature vector (see
     * {@link #image2double(BufferedImage)}).
     *
     * @param path directory to read; when null or empty, {@link GlobalPara#trainPath} is used
     * @return one feature vector per file, in the order the files are listed
     * @throws IOException if the directory cannot be listed or a file is not a readable image
     */
    public static double[][] loadTrainData(String path) throws IOException {
        String dirPath = (path == null || path.isEmpty()) ? GlobalPara.trainPath : path;
        File[] files = listFiles(new File(dirPath));
        double[][] result = new double[files.length][];
        for (int i = 0; i < files.length; i++) {
            BufferedImage image = ImageIO.read(files[i]);
            if (image == null) {
                // ImageIO.read returns null when no reader understands the file.
                throw new IOException("Not a readable image: " + files[i]);
            }
            result[i] = image2double(image);
        }
        return result;
    }

    /** Lists a directory, turning the null that signals failure into an exception. */
    private static File[] listFiles(File dir) throws IOException {
        File[] files = dir.listFiles();
        if (files == null) {
            throw new IOException("Cannot list directory: " + dir);
        }
        return files;
    }

    /**
     * Builds the one-hot label for the first character of a file name.
     *
     * @param file file whose name starts with a digit
     * @return a 10-element vector with 1 at the digit's index
     * @throws NumberFormatException if the character is not a digit
     */
    public static double[] getReal(File file) {
        return getReal(file, 0);
    }

    /**
     * Builds the one-hot label for the character at a given position of a file name.
     *
     * @param file file whose name carries the label digits
     * @param i zero-based position of the character in the file name
     * @return a 10-element vector with 1 at the digit's index
     * @throws NumberFormatException if the character is not a digit
     */
    public static double[] getReal(File file, int i) {
        double[] result = new double[DIGIT_CLASSES];
        int charValue = Integer.parseInt(String.valueOf(file.getName().charAt(i)));
        result[charValue] = 1;
        return result;
    }

    /**
     * Returns the index of the largest value as a string; on ties the first index wins.
     *
     * @param res network output values; must not be empty
     * @return the index of the maximum, formatted as a decimal string
     * @throws IllegalArgumentException if {@code res} is empty
     */
    public static String getOutputValue(double[] res) {
        return String.valueOf(argMax(res));
    }

    /** Returns the index of the largest element; the first one on ties. */
    private static int argMax(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("Cannot pick the maximum of an empty array.");
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Trains the network and prints the recognized text for every test image followed by the
     * overall per-character accuracy.
     *
     * @param args unused
     * @throws Exception if an image cannot be read or processed
     */
    public static void main(String[] args) throws Exception {
        BP bp = new BP(80, 80, 10, 0.6, 0.01);

        // Read and preprocess the training images once; the samples are the same in every epoch.
        List<double[]> trainInputs = new ArrayList<double[]>();
        List<double[]> trainTargets = new ArrayList<double[]>();
        for (File trainFile : listFiles(new File(GlobalPara.trainPath))) {
            BufferedImage img = ImagePreProcess.removeBackgroud(trainFile.getAbsolutePath());
            List<BufferedImage> listImg = ImagePreProcess.splitImage(img);
            for (int m = 0; m < listImg.size(); m++) {
                trainInputs.add(image2double(listImg.get(m)));
                trainTargets.add(getReal(trainFile, m));
            }
        }

        for (int epoch = 0; epoch < TRAINING_EPOCHS; epoch++) {
            for (int s = 0; s < trainInputs.size(); s++) {
                bp.train(trainInputs.get(s), trainTargets.get(s));
            }
        }
        System.out.println("训练结束");

        File[] testFiles = listFiles(new File(GlobalPara.imgPath));
        int count = 0;
        for (File file : testFiles) {
            StringBuilder sb = new StringBuilder();
            BufferedImage img = ImagePreProcess.removeBackgroud(file.getAbsolutePath());
            List<BufferedImage> listImg = ImagePreProcess.splitImage(img);
            for (int m = 0; m < listImg.size(); m++) {
                double[] result = bp.test(image2double(listImg.get(m)));
                int idx = argMax(result);
                int expected = Integer.parseInt(String.valueOf(file.getName().charAt(m)));
                if (expected == idx) {
                    count++;
                }
                sb.append(idx);
            }
            System.out.print("Input: " + file.getName());
            System.out.println(" Output: " + sb.toString());
        }
        // Each test image contributes four characters.
        System.out.println("准确率:" + (double) count / (testFiles.length * 4));
    }
}