BP神经网络验证码识别

上完机器学习课程后需要交一个大作业。学校场馆预定系统经常订不到场地,于是想做一个自动提交表单的工具,而其中的关键在于验证码识别。


图像处理代码

训练集:

验证集:

package mine.imageProcesser;

import jdk.nashorn.internal.objects.Global;
import mine.common.GlobalPara;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.imageio.ImageIO;

public class ImagePreProcess {

   /** Pixels whose R+G+B sum is at most this value are treated as foreground (black). */
   private static final int DARK_THRESHOLD = 100;

   /** Left x-coordinate of each fixed glyph cell in a captcha image. */
   private static final int[] GLYPH_LEFT = {10, 19, 28, 37};
   /** Top y-coordinate shared by all glyph cells. */
   private static final int GLYPH_TOP = 6;
   /** Width of one glyph cell in pixels. */
   private static final int GLYPH_WIDTH = 8;
   /** Height of one glyph cell in pixels. */
   private static final int GLYPH_HEIGHT = 10;

   /**
    * Tests whether a packed RGB pixel is dark.
    *
    * @param colorInt packed RGB value as returned by {@link BufferedImage#getRGB(int, int)}
    * @return 1 if R+G+B is at most 100, otherwise 0
    */
   public static int isBlack(int colorInt) {
      Color color = new Color(colorInt);
      return color.getRed() + color.getGreen() + color.getBlue() <= DARK_THRESHOLD ? 1 : 0;
   }

   /**
    * Tests whether a packed RGB pixel is light; the exact complement of {@link #isBlack(int)}.
    *
    * @param colorInt packed RGB value as returned by {@link BufferedImage#getRGB(int, int)}
    * @return 1 if R+G+B is greater than 100, otherwise 0
    */
   public static int isWhite(int colorInt) {
      return 1 - isBlack(colorInt);
   }

   /**
    * Reads an image and binarizes it in place: light pixels become pure white and all
    * other pixels pure black.
    *
    * @param picFile path of the image file
    * @return the binarized image
    * @throws Exception if the file cannot be read or is not a decodable image
    */
   public static BufferedImage removeBackgroud(String picFile)
         throws Exception {
      BufferedImage img = ImageIO.read(new File(picFile));
      if (img == null) {
         // ImageIO.read returns null (instead of throwing) when no reader understands the format.
         throw new java.io.IOException("Unsupported or unreadable image: " + picFile);
      }
      int white = Color.WHITE.getRGB();
      int black = Color.BLACK.getRGB();
      int width = img.getWidth();
      int height = img.getHeight();
      for (int x = 0; x < width; ++x) {
         for (int y = 0; y < height; ++y) {
            img.setRGB(x, y, isWhite(img.getRGB(x, y)) == 1 ? white : black);
         }
      }
      return img;
   }

   /**
    * Cuts a captcha into its four glyph cells at fixed positions.
    *
    * <p>The returned sub-images share pixel data with {@code img} (see
    * {@link BufferedImage#getSubimage}), so modifying one modifies the other.
    *
    * @param img the (binarized) captcha image
    * @return the glyph images in left-to-right order
    * @throws Exception if a cell lies outside {@code img}
    */
   public static List<BufferedImage> splitImage(BufferedImage img)
         throws Exception {
      List<BufferedImage> subImgs = new ArrayList<BufferedImage>(GLYPH_LEFT.length);
      for (int left : GLYPH_LEFT) {
         subImgs.add(img.getSubimage(left, GLYPH_TOP, GLYPH_WIDTH, GLYPH_HEIGHT));
      }
      return subImgs;
   }

   /**
    * Loads every image in {@link GlobalPara#trainPath} as a single-character template,
    * labelled with the first character of its file name.
    *
    * <p>Sub-directories and files that ImageIO cannot decode are skipped.
    *
    * @return map from template image to its label
    * @throws Exception if the directory cannot be listed or a file cannot be read
    */
   public static Map<BufferedImage, String> loadTrainData() throws Exception {
      Map<BufferedImage, String> map = new HashMap<BufferedImage, String>();
      File dir = new File(GlobalPara.trainPath);
      File[] files = dir.listFiles();
      if (files == null) {
         throw new java.io.IOException("Cannot list training directory: " + dir.getAbsolutePath());
      }
      for (File file : files) {
         if (!file.isFile()) {
            continue;
         }
         BufferedImage template = ImageIO.read(file);
         if (template == null) {
            // Not a decodable image; a null key would crash getSingleCharOcr later.
            continue;
         }
         map.put(template, file.getName().substring(0, 1));
      }
      return map;
   }

   /**
    * Recognizes one glyph by choosing the template with the fewest differing pixels.
    *
    * <p>Templates smaller than the glyph in either dimension are skipped because they
    * cannot be compared pixel by pixel.
    *
    * @param img the glyph to recognize
    * @param map templates and their labels
    * @return the label of the best template, or "" if no template differs in fewer than
    *         width*height pixels
    */
   public static String getSingleCharOcr(BufferedImage img,
         Map<BufferedImage, String> map) {
      String result = "";
      int width = img.getWidth();
      int height = img.getHeight();
      int min = width * height;
      for (Map.Entry<BufferedImage, String> entry : map.entrySet()) {
         BufferedImage template = entry.getKey();
         if (template.getWidth() < width || template.getHeight() < height) {
            continue;
         }
         int count = countMismatches(img, template, min);
         if (count < min) {
            min = count;
            result = entry.getValue();
         }
      }
      return result;
   }

   /**
    * Counts pixels whose white/black classification differs between {@code img} and the
    * top-left region of {@code template}, stopping early once {@code limit} is reached.
    */
   private static int countMismatches(BufferedImage img, BufferedImage template, int limit) {
      int count = 0;
      for (int x = 0; x < img.getWidth(); ++x) {
         for (int y = 0; y < img.getHeight(); ++y) {
            if (isWhite(img.getRGB(x, y)) != isWhite(template.getRGB(x, y))) {
               count++;
               if (count >= limit) {
                  return count;
               }
            }
         }
      }
      return count;
   }

   /**
    * Recognizes a whole captcha, loading the templates from disk.
    *
    * @param file path of the captcha image
    * @return the recognized text
    * @throws Exception on I/O failure
    */
   public static String getAllOcr(String file) throws Exception {
      return getAllOcr(file, loadTrainData());
   }

   /**
    * Recognizes a whole captcha using preloaded templates, and writes the binarized image
    * to {@link GlobalPara#resultPath} named after the recognized text.
    *
    * @param file      path of the captcha image
    * @param templates templates and labels, e.g. from {@link #loadTrainData()}
    * @return the recognized text
    * @throws Exception on I/O failure
    */
   public static String getAllOcr(String file, Map<BufferedImage, String> templates)
         throws Exception {
      BufferedImage img = removeBackgroud(file);
      StringBuilder text = new StringBuilder();
      for (BufferedImage glyph : splitImage(img)) {
         text.append(getSingleCharOcr(glyph, templates));
      }
      String result = text.toString();
      ImageIO.write(img, "JPG", new File(GlobalPara.resultPath + result + ".jpg"));
      return result;
   }

   /**
    * Recognizes images 0.jpg .. 29.jpg from {@link GlobalPara#imgPath} and prints the results.
    *
    * @param args unused
    * @throws Exception on I/O failure
    */
   public static void main(String[] args) throws Exception {
      // Load the templates once instead of once per image.
      Map<BufferedImage, String> templates = loadTrainData();
      for (int i = 0; i < 30; ++i) {
         String text = getAllOcr(GlobalPara.imgPath + i + ".jpg", templates);
         System.out.println(i + ".jpg = " + text);
      }
   }
}
-------------------------------------------------------------------------

BP神经网络
package mine;

/**
 * Created by robby on 2017/6/5.
 */

import java.util.Random;

public class BP {
    //输入层
    private final double[] input;
    //隐藏层
    private final double[] hidden;
    //输出层
    private final double[] output;
    //目标
    private final double[] target;
    //隐藏层误差
    private final double[] hidDelta;
    //输出层误差
    private final double[] optDelta;

    //学习率
    private final double eta;
    //动态率
    private final double momentum;

    //输入层和隐藏层的链接权重
    private final double[][] iptHidWeights;
    //隐藏层和输出层的链接权重
    private final double[][] hidOptWeights;
    //输入层和隐藏层之前的权重
    private final double[][] iptHidPrevUptWeights;
    //隐藏层和输出层之前的权重
    private final double[][] hidOptPrevUptWeights;

    public double optErrSum = 0d;

    public double hidErrSum = 0d;

    private final Random random;

    /**
     * 带参数的构造方法
     *
     * @param inputSize
     * @param hiddenSize
     * @param outputSize
     * @param eta
     * @param momentum
     */
    public BP(int inputSize, int hiddenSize, int outputSize, double eta,
              double momentum) {

        input = new double[inputSize + 1];
        hidden = new double[hiddenSize + 1];
        output = new double[outputSize + 1];
        target = new double[outputSize + 1];

        hidDelta = new double[hiddenSize + 1];
        optDelta = new double[outputSize + 1];

        iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptWeights = new double[hiddenSize + 1][outputSize + 1];

        random = new Random(19881211);
        randomizeWeights(iptHidWeights);
        randomizeWeights(hidOptWeights);

        iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];

        this.eta = eta;
        this.momentum = momentum;
    }

    private void randomizeWeights(double[][] matrix) {
        for (int i = 0, len = matrix.length; i != len; i++)
            for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
                double real = random.nextDouble();
                matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
            }
    }

    /**
     * 使用默认参数的构造方法
     *
     * @param inputSize
     * @param hiddenSize
     * @param outputSize
     */
    public BP(int inputSize, int hiddenSize, int outputSize) {
        this(inputSize, hiddenSize, outputSize, 0.5, 0.9);
    }

    /**
     * 训练器
     *
     * @param trainData
     * @param target
     */
    public void train(double[] trainData, double[] target) {
        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateDelta();
        adjustWeight();
    }

    /**
     * 测试器
     *
     * @param inData
     * @return
     */
    public double[] test(double[] inData) {
        if (inData.length != input.length - 1) {
            throw new IllegalArgumentException("Size Do Not Match.");
        }
        System.arraycopy(inData, 0, input, 1, inData.length);
        forward();
        return getNetworkOutput();
    }

    /**
     * 获得输出层
     *
     * @return
     */
    private double[] getNetworkOutput() {
        int len = output.length;
        double[] temp = new double[len - 1];
        for (int i = 1; i != len; i++)
            temp[i - 1] = output[i];
        return temp;
    }

    /**
     * 加载目标数据
     *
     * @param arg
     */
    private void loadTarget(double[] arg) {
        if (arg.length != target.length - 1) {
            throw new IllegalArgumentException("Size Do Not Match.");
        }
        System.arraycopy(arg, 0, target, 1, arg.length);
    }

    /**
     * 加载训练数据
     *
     * @param inData
     */
    private void loadInput(double[] inData) {
        if (inData.length != input.length - 1) {
            throw new IllegalArgumentException("Size Do Not Match.");
        }
        System.arraycopy(inData, 0, input, 1, inData.length);
    }

    /**
     * 前向计算
     *
     * @param layer0
     * @param layer1
     * @param weight
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        layer0[0] = 1.0;
        for (int j = 1, len = layer1.length; j != len; ++j) {
            double sum = 0;
            for (int i = 0, len2 = layer0.length; i != len2; ++i)
                sum += weight[i][j] * layer0[i];
            layer1[j] = sigmoid(sum);
        }
    }

    /**
     * 前向计算
     */
    private void forward() {
        forward(input, hidden, iptHidWeights);
        forward(hidden, output, hidOptWeights);
    }

    /**
     * 计算输出误差
     */
    private void outputErr() {
        double errSum = 0;
        for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
            double o = output[idx];
            optDelta[idx] = o * (1d - o) * (target[idx] - o);
            errSum += Math.abs(optDelta[idx]);
        }
        optErrSum = errSum;
    }

    /**
     * 计算隐藏误差
     */
    private void hiddenErr() {
        double errSum = 0;
        for (int j = 1, len = hidDelta.length; j != len; ++j) {
            double o = hidden[j];
            double sum = 0;
            for (int k = 1, len2 = optDelta.length; k != len2; ++k)
                sum += hidOptWeights[j][k] * optDelta[k];
            hidDelta[j] = o * (1d - o) * sum;
            errSum += Math.abs(hidDelta[j]);
        }
        hidErrSum = errSum;
    }

    /**
     * 计算全部层数的误差
     */
    private void calculateDelta() {
        outputErr();
        hiddenErr();
    }

    /**
     * 调节权重
     *
     * @param delta
     * @param layer
     * @param weight
     * @param prevWeight
     */
    private void adjustWeight(double[] delta, double[] layer,
                              double[][] weight, double[][] prevWeight) {

        layer[0] = 1;
        for (int i = 1, len = delta.length; i != len; ++i) {
            for (int j = 0, len2 = layer.length; j != len2; ++j) {
                double newVal = momentum * prevWeight[j][i] + eta * delta[i]
                        * layer[j];
                weight[j][i] += newVal;
                prevWeight[j][i] = newVal;
            }
        }
    }

    /**
     * 调节各层的权重
     */
    private void adjustWeight() {
        adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
        adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
    }

    /**
     * 符号判别函数
     *
     * @param val
     * @return
     */
    private double sigmoid(double val) {
        return 1d / (1d + Math.exp(-val));
    }
}
-----------------------------------------------------------------------------

全局变量
package mine.common;

/**
 * Created by robby on 2017/6/5.
 */
public class GlobalPara {
    /**
     * Directory of training captchas. Override with the {@code recognizer.trainPath} system
     * property; the value always ends with a path separator so file names can be appended.
     */
    public static final String trainPath =
            dirProperty("recognizer.trainPath", "D:\\Projects\\ml\\recognizer\\data\\train\\");
    /**
     * Directory of captchas to recognize. Override with {@code recognizer.imgPath}; the value
     * always ends with a path separator.
     */
    public static final String imgPath =
            dirProperty("recognizer.imgPath", "D:\\Projects\\ml\\recognizer\\data\\img\\");
    /**
     * Directory where recognized images are written. Override with
     * {@code recognizer.resultPath}; the value always ends with a path separator.
     */
    public static final String resultPath =
            dirProperty("recognizer.resultPath", "D:\\Projects\\ml\\recognizer\\data\\result\\");

    /**
     * Reads a directory path from a system property, falling back to {@code defaultValue},
     * and appends the platform separator if the value does not already end with one.
     */
    private static String dirProperty(String key, String defaultValue) {
        String value = System.getProperty(key, defaultValue);
        if (value.endsWith("/") || value.endsWith("\\")) {
            return value;
        }
        return value + java.io.File.separator;
    }
}
----------------------------------------------------------------------------------

主函数类
package mine;

/**
 * Created by robby on 2017/6/5.
 */

import mine.common.GlobalPara;
import mine.imageProcesser.ImagePreProcess;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;

public class Main {

    /** Number of output classes: one per decimal digit 0-9. */
    private static final int DIGIT_CLASSES = 10;

    /**
     * Pixels whose R+G+B sum is at most this value count as foreground. This mirrors the
     * threshold ImagePreProcess uses when binarizing captchas.
     */
    private static final int DARK_THRESHOLD = 100;

    /** Number of passes over the training set. */
    private static final int EPOCHS = 500;

    /**
     * Converts an image into a network input vector with one entry per pixel: 1.0 for a dark
     * (foreground) pixel, 0.0 otherwise.
     *
     * <p>Pixel (x, y) is stored at index {@code x * height + y}, so every pixel has its own slot.
     * Pixels are encoded as 0/1 rather than as raw ARGB integers (around -1.6e7), because such
     * magnitudes saturate the network's sigmoid units and stall learning.
     *
     * @param image the image to encode
     * @return an array of length width * height
     */
    public static double[] image2double(BufferedImage image){
        int width = image.getWidth();
        int height = image.getHeight();
        double[] result = new double[width * height];
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                Color color = new Color(image.getRGB(x, y));
                int brightness = color.getRed() + color.getGreen() + color.getBlue();
                result[x * height + y] = brightness <= DARK_THRESHOLD ? 1.0 : 0.0;
            }
        }
        return result;
    }

    /**
     * Encodes every image file in a directory with {@link #image2double(BufferedImage)}.
     *
     * @param path directory to read
     * @return one feature vector per file, in sorted file order
     * @throws IOException if the directory cannot be listed or a file is not a readable image
     */
    public static double[][] loadTrainData(String path) throws IOException {
        File[] files = listFilesOrFail(new File(path));
        double[][] result = new double[files.length][];
        for (int i = 0; i < files.length; i++) {
            BufferedImage image = ImageIO.read(files[i]);
            if (image == null) {
                throw new IOException("Not a readable image: " + files[i].getAbsolutePath());
            }
            result[i] = image2double(image);
        }
        return result;
    }

    /**
     * One-hot target for the digit at the start of the file name.
     *
     * @param file labelled image file
     * @return array of length 10 with a 1 at the digit's position
     * @throws NumberFormatException if the first character is not a digit
     */
    public static double[] getReal(File file){
        return getReal(file, 0);
    }

    /**
     * One-hot target for the digit at position {@code i} of the file name.
     *
     * @param file labelled image file
     * @param i    index of the character in the file name
     * @return array of length 10 with a 1 at the digit's position
     * @throws NumberFormatException if that character is not a digit
     */
    public static double[] getReal(File file, int i){
        double[] result = new double[DIGIT_CLASSES];
        result[digitAt(file.getName(), i)] = 1;
        return result;
    }

    /**
     * Returns the index of the largest output as a string (the predicted digit).
     *
     * @param res network outputs
     * @return index of the first maximum
     * @throws IllegalArgumentException if {@code res} is empty
     */
    public static String getOutputValue(double[] res){
        return String.valueOf(argMax(res));
    }

    /**
     * Trains the network on {@link GlobalPara#trainPath}, then reports per-image predictions
     * and overall per-character accuracy on {@link GlobalPara#imgPath}.
     *
     * @param args unused
     * @throws Exception on I/O or image-processing failure
     */
    public static void main(String[] args) throws Exception {
        BP bp = new BP(80, 80, DIGIT_CLASSES, 0.6, 0.01);
        trainNetwork(bp);
        System.out.println("训练结束");
        evaluate(bp);
    }

    /**
     * Preprocesses the training captchas once, then runs {@link #EPOCHS} online passes over them.
     * Each file name's characters are the labels of the glyphs in order.
     */
    private static void trainNetwork(BP bp) throws Exception {
        List<double[]> inputs = new ArrayList<double[]>();
        List<double[]> targets = new ArrayList<double[]>();
        for (File trainFile : listFilesOrFail(new File(GlobalPara.trainPath))) {
            BufferedImage img = ImagePreProcess.removeBackgroud(trainFile.getAbsolutePath());
            List<BufferedImage> glyphs = ImagePreProcess.splitImage(img);
            for (int m = 0; m < glyphs.size(); m++) {
                inputs.add(image2double(glyphs.get(m)));
                targets.add(getReal(trainFile, m));
            }
        }
        // BP.train copies its arguments, so the cached arrays can be reused every epoch.
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            for (int s = 0; s < inputs.size(); s++) {
                bp.train(inputs.get(s), targets.get(s));
            }
        }
    }

    /**
     * Prints the prediction for every test captcha and the fraction of correctly recognized
     * characters.
     */
    private static void evaluate(BP bp) throws Exception {
        File[] testFiles = listFilesOrFail(new File(GlobalPara.imgPath));
        int correct = 0;
        int total = 0;
        for (File file : testFiles) {
            StringBuilder sb = new StringBuilder();
            BufferedImage img = ImagePreProcess.removeBackgroud(file.getAbsolutePath());
            List<BufferedImage> glyphs = ImagePreProcess.splitImage(img);
            for (int m = 0; m < glyphs.size(); m++) {
                int predicted = argMax(bp.test(image2double(glyphs.get(m))));
                if (predicted == digitAt(file.getName(), m)) {
                    correct++;
                }
                total++;
                sb.append(predicted);
            }
            System.out.print("Input: " + file.getName());
            System.out.println(" Output: " + sb.toString());
        }
        System.out.println("准确率:" + (double) correct / total);
    }

    /** Parses the decimal digit at {@code index} of {@code name}. */
    private static int digitAt(String name, int index) {
        return Integer.parseInt(String.valueOf(name.charAt(index)));
    }

    /** Returns the index of the first maximum element of a non-empty array. */
    private static int argMax(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Lists a directory's entries in sorted order (so runs are reproducible), failing
     * loudly instead of returning null.
     */
    private static File[] listFilesOrFail(File dir) throws IOException {
        File[] files = dir.listFiles();
        if (files == null) {
            throw new IOException("Cannot list directory: " + dir.getAbsolutePath());
        }
        Arrays.sort(files);
        return files;
    }
}



 

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值