车牌识别算法

2023-7-19 日更新:

添加了安卓推理测试。分别使用 onnxruntime 和 ncnn 部署,效果都差不多。

2023-4-17 日更新: 在 YOLOv5 上增加了关键点(key point)检测,检测出车牌的 4 个角点,经透视(投影)变换矫正后再使用 CRNN 进行字符识别。参考代码:
https://github.com/we0091234/Chinese_license_plate_detection_recognition.git

原理比较简单,核心是透视(投影)变换,效果见上图。下面是 Java 推理代码。


import ai.onnxruntime.*;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.List;

/**
*   @desc : 车牌检测 + 车牌字符/颜色识别
*   @auth : tyf
*   @date : 2023-04-26  18:06:14
*/
public class yolov5_car_plate {

    // 模型1
    public static OrtEnvironment env1;
    public static OrtSession session1;

    // 模型2
    public static OrtEnvironment env2;
    public static OrtSession session2;

    // 记录一个图片的信息
    public static class ImageObj{
        // 图片模型尺寸用于推理
        Mat src;
        // 图片原始尺寸用于绘图
        Mat background;
        // 过滤后的边框信息
        List<float[]> data;
        // 投影变换后的车牌矩阵
        List<Mat> platesMat = new ArrayList<>();
        // 投影变换后的车牌
        List<String> platesStr = new ArrayList<>();
        // 车牌的颜色
        List<Color> platesColor = new ArrayList<>();
        // 颜色
        Scalar color1 = new Scalar(0, 0, 255);
        Scalar color2 = new Scalar(0, 255, 0);
        // 投影变换后车牌的宽高,也就是第二个模型的输入尺寸
        int plateWidth = 168;
        int plateHeight = 48;
        // 原始图片尺寸,也就是第一个模型的输入尺寸
        int picWidth = 640;
        int picHeight = 640;
        // 车牌类别
        char[] plateChar = new char[]{
                '#','京','沪','津','渝','冀','晋','蒙','辽','吉',
                '黑','苏','浙','皖','闽','赣','鲁','豫','鄂','湘',
                '粤','桂','琼','川','贵','云','藏','陕','甘','青',
                '宁','新','学','警','港','澳','挂','使','领','民',
                '航','危','0','1','2','3','4','5','6','7',
                '8','9','A','B','C','D','E','F','G','H',
                'J','K','L','M','N','P','Q','R','S','T',
                'U','V','W','X','Y','Z','险','品',
        };
        // 车牌颜色类别 color=['黑色','蓝色','绿色','白色','黄色']
        Color [] plateScalar = new Color []{
                Color.BLACK,
                Color.BLUE,
                Color.GREEN,
                Color.WHITE,
                Color.YELLOW
        };
        // 宽高缩放比
        float wScale;
        float hScale;
        public ImageObj(String img) {
            // 原始图像
            this.background = readImg(img);
            // 缩放过后的图像
            this.src = resizeWithoutPadding(this.background,this.picWidth,this.picHeight);
            // 保存缩放比
            this.wScale = Float.valueOf(src.width())/ Float.valueOf(background.width());
            this.hScale = Float.valueOf(src.height())/Float.valueOf(background.height());
        }
        public void setDataAndFilter(float[][] output){

            // xywh  objscore   class1 class2  x1y1 x2y2 x3y3 x4y4

            float confThreshold = 0.75f;
            float nmsThreshold = 0.45f; // 车牌识别省略nms

            List<float[]> temp = new ArrayList<>();

            // 置信度过滤
            for(int i=0;i<output.length;i++){
                float[] obj = output[i];
                float x = obj[0];
                float y = obj[1];
                float w = obj[2];
                float h = obj[3];
                float score = obj[4];
                float x1 = obj[5];
                float y1 = obj[6];
                float x2 = obj[7];
                float y2 = obj[8];
                float x3 = obj[9];
                float y3 = obj[10];
                float x4 = obj[11];
                float y4 = obj[12];
                float class1 = obj[13];
                float class2 = obj[14];
                if(score>=confThreshold){
                    // 边框坐标
                    float[] xyxy = xywh2xyxy(new float[]{x,y,w,h},this.picWidth,this.picHeight);
                    // 类别1或者2
                    float clazz = class1>class2?1:2;
                    // 类别概率
                    float clazzScore =  class1>class2?class1:class2;
                    // 关键点坐标
                    temp.add(new float[]{
                            xyxy[0], xyxy[1], xyxy[2], xyxy[3], x1, y1, x2, y2, x3, y3, x4, y4,clazz,clazzScore
                    });
                }
            }

            // 交并比过滤
            // 先按照概率排序
            temp.sort((o1, o2) -> Float.compare(o2[13],o1[13]));

            // 保存最终的过滤结果
            List<float[]> out = new ArrayList<>();
            while (!temp.isEmpty()){
                float[] max = temp.get(0);
                out.add(max);
                Iterator<float[]> it = temp.iterator();
                while (it.hasNext()) {
                    float[] obj = it.next();
                    // 交并比
                    double iou = calculateIoU(
                            new float[]{max[0],max[1],max[2],max[3]},
                            new float[]{obj[0],obj[1],obj[2],obj[3]}
                    );
                    if (iou > nmsThreshold) {
                        it.remove();
                    }
                }
            }
            // 保存最终的边框
            this.data = out;
        }

        // 对所有车牌关键点进行透视变换,拉成一个矩形
        public void transform(){

            // 首先对每个车牌目标进行关键点透视变换

            this.data.stream().forEach(n->{

                float key_point_x1 = n[4];
                float key_point_y1 = n[5];
                float key_point_x2 = n[6];
                float key_point_y2 = n[7];
                float key_point_x3 = n[8];
                float key_point_y3 = n[9];
                float key_point_x4 = n[10];
                float key_point_y4 = n[11];

                Point[] srcPoints = new Point[4];
                Point p1 = new Point(Float.valueOf(key_point_x1).intValue(), Float.valueOf(key_point_y1).intValue());
                Point p2 = new Point(Float.valueOf(key_point_x2).intValue(), Float.valueOf(key_point_y2).intValue());
                Point p3 = new Point(Float.valueOf(key_point_x3).intValue(), Float.valueOf(key_point_y3).intValue());
                Point p4 = new Point(Float.valueOf(key_point_x4).intValue(), Float.valueOf(key_point_y4).intValue());
                srcPoints[0] = p1;
                srcPoints[1] = p2;
                srcPoints[2] = p3;
                srcPoints[3] = p4;

                // 定义透视变换后的目标矩形的四个角点,指定车牌的宽和高
                Point[] dstPoints = new Point[4];
                dstPoints[0] = new Point(0, 0);
                dstPoints[1] = new Point(plateWidth, 0);
                dstPoints[2] = new Point(plateWidth, plateHeight);
                dstPoints[3] = new Point(0, plateHeight);

                // 计算透视变换矩阵
                MatOfPoint2f in1 = new MatOfPoint2f(srcPoints);
                MatOfPoint2f in2 = new MatOfPoint2f(dstPoints);
                Mat M = Imgproc.getPerspectiveTransform(in1, in2);

                // 进行透视变换
                Mat warped = new Mat();
                Imgproc.warpPerspective(src, warped, M, new Size(plateWidth, plateHeight));

                // 保存透视变换得到的车牌
                platesMat.add(warped);

            });


        }

        public void drawBox(){

            // 在原始图片尺寸上绘制,需要坐标转换

            // 遍历每个车牌框
            for(int i=0; i<this.data.size() ; i++ ){

                float[] n = data.get(i);

                // 位置信息
                float x1 = n[0] / wScale;
                float y1 = n[1] / hScale;
                float x2 = n[2] / wScale;
                float y2 = n[3] / hScale;
                float key_point_x1 = n[4] / wScale;
                float key_point_y1 = n[5] / hScale;
                float key_point_x2 = n[6] / wScale;
                float key_point_y2 = n[7] / hScale;
                float key_point_x3 = n[8] / wScale;
                float key_point_y3 = n[9] / hScale;
                float key_point_x4 = n[10] / wScale;
                float key_point_y4 = n[11] / hScale;
                float clazz = n[12];
                float clazzScore = n[13];

                // 画边框
                Imgproc.rectangle(
                        background,
                        new Point(Float.valueOf(x1).intValue(), Float.valueOf(y1).intValue()),
                        new Point(Float.valueOf(x2).intValue(), Float.valueOf(y2).intValue()),
                        color1,
                        2);
                // 画关键点四个
                Imgproc.circle(
                        background,
                        new Point(Float.valueOf(key_point_x1).intValue(), Float.valueOf(key_point_y1).intValue()),
                        3, // 半径
                        color2,
                        2);
                Imgproc.circle(
                        background,
                        new Point(Float.valueOf(key_point_x2).intValue(), Float.valueOf(key_point_y2).intValue()),
                        3, // 半径
                        color2,
                        2);
                Imgproc.circle(
                        background,
                        new Point(Float.valueOf(key_point_x3).intValue(), Float.valueOf(key_point_y3).intValue()),
                        3, // 半径
                        color2,
                        2);
                Imgproc.circle(
                        background,
                        new Point(Float.valueOf(key_point_x4).intValue(), Float.valueOf(key_point_y4).intValue()),
                        3, // 半径
                        color2,
                        2);

                // 获取车牌
                String number = platesStr.get(i);
            }
        }
    }

    // 环境初始化
    public static void init1(String weight) throws Exception{
        // opencv 库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env1 = OrtEnvironment.getEnvironment();
        session1 = env1.createSession(weight, new OrtSession.SessionOptions());
    }


    // 环境初始化
    public static void init2(String weight) throws Exception{
        // opencv 库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env2 = OrtEnvironment.getEnvironment();
        session2 = env2.createSession(weight, new OrtSession.SessionOptions());

    }

    // Mat 转 BufferedImage
    public static BufferedImage mat2BufferedImage(Mat mat){
        BufferedImage bufferedImage = null;
        try {
            // 将Mat对象转换为字节数组
            MatOfByte matOfByte = new MatOfByte();
            Imgcodecs.imencode(".jpg", mat, matOfByte);
            // 创建Java的ByteArrayInputStream对象
            byte[] byteArray = matOfByte.toArray();
            ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArray);
            // 使用ImageIO读取ByteArrayInputStream并将其转换为BufferedImage对象
            bufferedImage = ImageIO.read(byteArrayInputStream);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return bufferedImage;
    }


    public static float[] xywh2xyxy(float[] bbox,float maxWidth,float maxHeight) {
        // 中心点坐标
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        // 计算
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        // 限制在图片区域内
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > maxWidth ? maxWidth:x2,
                y2 > maxHeight? maxHeight:y2};
    }

    public static Mat readImg(String path){
        Mat img = Imgcodecs.imread(path);
        return img;
    }


    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    public static OnnxTensor transferTensor(Mat dst,int channels,int netWidth,int netHeight){

        // BGR -> RGB
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);

        //  归一化 0-255 转 0-1
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);

        // 初始化一个输入数组 channels * netWidth * netHeight
        float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
        dst.get(0, 0, whc);

        // 得到最终的图片转 float 数组
        float[] chw = whc2cwh(whc);

        // 创建 onnxruntime 需要的 tensor
        // 传入输入的图片 float 数组并指定数组shape
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env1, FloatBuffer.wrap(chw), new long[]{1,channels,netHeight,netWidth});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
        return tensor;
    }

    public static OnnxTensor transferTensor2(Mat dst,int channels,int netWidth,int netHeight){

        // BGR -> RGB
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);

        double[] meanValue = {0.588, 0.588, 0.588};
        double[] stdValue = {0.193, 0.193, 0.193};

        // Convert image to float and normalize using mean and standard deviation values
        dst.convertTo(dst, CvType.CV_32FC3, 1.0 / 255.0);
        Core.subtract(dst, new Scalar(meanValue), dst);
        Core.divide(dst, new Scalar(stdValue), dst);

        // 初始化一个输入数组 channels * netWidth * netHeight
        float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
        dst.get(0, 0, whc);

        // 得到最终的图片转 float 数组
        float[] chw = whc2cwh(whc);

        // 创建 onnxruntime 需要的 tensor
        // 传入输入的图片 float 数组并指定数组shape
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env1, FloatBuffer.wrap(chw), new long[]{1,channels,netHeight,netWidth});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
        return tensor;
    }

    // 计算两个框的交并比
    private static double calculateIoU(float[] box1, float[] box2) {

        //  getXYXY() 返回 xmin-0 ymin-1 xmax-2 ymax-3

        double x1 = Math.max(box1[0], box2[0]);
        double y1 = Math.max(box1[1], box2[1]);
        double x2 = Math.min(box1[2], box2[2]);
        double y2 = Math.min(box1[3], box2[3]);
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
        double box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }

    public static int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxVal = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxVal) {
                maxVal = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    // 将一个 src_mat 修改尺寸后存储到 dst_mat 中
    public static Mat resizeWithoutPadding(Mat src, int netWidth, int netHeight) {
        // 调整图像大小
        Mat resizedImage = new Mat();
        Size size = new Size(netWidth, netHeight);
        Imgproc.resize(src, resizedImage, size, 0, 0, Imgproc.INTER_AREA);
        return resizedImage;
    }


    // 车牌检测,以及4个关键点
    public static void doDetect(ImageObj imageObj) throws Exception{

        // 输入矩阵
        Mat in = imageObj.src.clone();
        // 转为tensor
        OnnxTensor tensor = transferTensor(in,3,imageObj.picWidth,imageObj.picHeight);
        // 推理
        OrtSession.Result res = session1.run(Collections.singletonMap("input", tensor));
        // 解析 output -> [1, 25200, 15] -> FLOAT
        float[][] data = ((float[][][])(res.get(0)).getValue())[0];
        // 根据置信度、交并比过滤
        imageObj.setDataAndFilter(data);


    }


    // 识别车牌
    public static void doRecect(ImageObj imageObj){

        // 先将关键点透视变换为矩形方便识别,目标尺寸就是第二个模型的输入 168*48
        imageObj.transform();

        // 第二个模型是crnn 输入投影变换后的车牌图片即可
        imageObj.platesMat.stream().forEach(plate->{

            try {
                OnnxTensor tensor = transferTensor2(plate.clone(),3,plate.width(),plate.height());
                OrtSession.Result res = session2.run(Collections.singletonMap("images", tensor));
                float[][] data1 = ((float[][][])(res.get(0)).getValue())[0];
                // 遍历每个字符
                char last = '-';
                List<Character> chars = new ArrayList<>();
                for(int i=0;i<data1.length;i++){
                    int maxIndex = getMaxIndex(data1[i]);
                    char maxName = imageObj.plateChar[maxIndex];
                    if( maxIndex!=0 && maxName!=last ){
                        chars.add(maxName);
                    }
                    last = maxName;
                }

                StringBuffer car = new StringBuffer();
                chars.stream().forEach(n->{
                    car.append(n);
                });
                imageObj.platesStr.add(car.toString());


                // 5 代表五个颜色
                float[] data2 = ((float[][])(res.get(1)).getValue())[0];
                int maxIndex = getMaxIndex(data2);
                Color color = imageObj.plateScalar[maxIndex];// 从类别下表中查找
                imageObj.platesColor.add(color);
            }
            catch (Exception e){
                e.printStackTrace();
            }

        });

    }


    // 弹窗显示所有信息
    public static void showJpanel(ImageObj img){

        JFrame frame = new JFrame("Car");

        // 一行两列
        JPanel parent = new JPanel();

        // 显示图片
        JPanel p1 = new JPanel();
        p1.add(new JLabel(new ImageIcon(mat2BufferedImage(img.background))));

        // 显示车牌子图片
        JPanel p2 = new JPanel(new FlowLayout(FlowLayout.LEFT, 20, 20));
        JPanel sub = new JPanel(new GridLayout(img.platesMat.size()+1, 1, 0, 5));
//        sub.setLayout(new BoxLayout(sub, BoxLayout.Y_AXIS));
        JPanel title = new JPanel(new GridLayout(1,3,10,10));
        JLabel label1 = new JLabel("投影变换");
        label1.setHorizontalAlignment(JLabel.CENTER);
        title.add(label1);
        JLabel label2 = new JLabel("车牌号");
        label2.setHorizontalAlignment(JLabel.CENTER);
        title.add(label2);
        JLabel label3 = new JLabel("颜色");
        label3.setHorizontalAlignment(JLabel.CENTER);
        title.add(label3);
        sub.add(title);
        for(int i=0;i<img.platesMat.size();i++){
            // 每个车牌占一行
            JPanel line = new JPanel(new GridLayout(1,3,10,10));
            // 车牌图片
            JLabel jLabel1 = new JLabel(new ImageIcon(mat2BufferedImage(img.platesMat.get(i))));
            // 车牌号
            JLabel jLabel2 = new JLabel(img.platesStr.get(i));
            // 车牌颜色
            JLabel jLabel3 = new JLabel("█");
            jLabel3.setForeground(img.platesColor.get(i));
            // 居中
            jLabel1.setHorizontalAlignment(JLabel.CENTER);
            jLabel2.setHorizontalAlignment(JLabel.CENTER);
            jLabel3.setHorizontalAlignment(JLabel.CENTER);
            line.add(jLabel1);
            line.add(jLabel2);
            line.add(jLabel3);
            sub.add(line);
        }
        p2.add(sub);

        parent.add(p1);
        parent.add(p2);

        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.getContentPane().add(parent);
        frame.pack();
        frame.setVisible(true);

    }


    public static void main(String[] args) throws Exception{


        // 模型初始化 车牌检测、车牌识别
        init1(new File("").getCanonicalPath()+"\\src\\main\\resources\\deeplearning\\yolov5_car_plate\\plate_detect.onnx");
        init2(new File("").getCanonicalPath()+"\\src\\main\\resources\\deeplearning\\yolov5_car_plate\\plate_rec_color.onnx");

        // 原始图片
        ImageObj img = new ImageObj(new File("").getCanonicalPath()+"\\src\\main\\resources\\deeplearning\\yolov5_car_plate\\car.png");

        // 车牌区域检测
        doDetect(img);

        // 车牌识别
        doRecect(img);

        // 原图绘制边框
        img.drawBox();

        // 弹窗显示
        showJpanel(img);

    }

}


2019-5-22 更新:

一般来说,在摄像头前端(头端)用 OpenCV 纯图像处理做车牌识别,能达到每张 300-400 毫秒,这已经是很成熟的商用方案了。

其他识别方案,比如地感 + 云端识别,也在每张 500 毫秒左右。

尝试过用深度学习的方法来做车牌检测+车牌识别,比较慢。

开源的 EasyPR 使用的是 SVM 等比较传统的方法,实测识别一张约 400 毫秒。这里直接复用了 EasyPR 的训练结果,效果如下:

先用纯图像处理方法把候选车牌区域检测出来,第二步再判断该区域是否包含车牌,本质上是一个二分类问题。JavaCV 代码如下:

package com.ist.EasyPr;

import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.TermCriteria;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacpp.opencv_ml.TrainData;
import org.bytedeco.javacpp.opencv_ml.SVM;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.opencv.imgproc.Imgproc;
import javax.swing.*;
import java.io.File;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_core.FileStorage.READ;
import static org.bytedeco.javacpp.opencv_core.FileStorage.WRITE;
import static org.bytedeco.javacpp.opencv_imgcodecs.IMREAD_GRAYSCALE;
import static org.bytedeco.javacpp.opencv_imgcodecs.imread;
import static org.bytedeco.javacpp.opencv_imgproc.CV_THRESH_BINARY;
import static org.bytedeco.javacpp.opencv_imgproc.CV_THRESH_OTSU;
import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE;
import static org.bytedeco.javacpp.opencv_ml.SVM.C_SVC;
import static org.bytedeco.javacpp.opencv_ml.SVM.RBF;
import static org.bytedeco.javacv.JavaCV.FLT_EPSILON;
import static org.opencv.ml.SVM.LINEAR;

/**
 * @desc : SVM对正负样本分类,得到包含车牌的图片
 * @auth : TYF
 * @date : 2019-05-22 - 13:56
 */
/**
 * @desc : SVM plate / non-plate classifier (second stage of an EasyPR-style pipeline).
 *         Feature vector: the Otsu-binarised grayscale sample flattened into one CV_32F row.
 *         Labels: 0 = no plate (negative), 1 = plate (positive).
 * @auth : TYF
 * @date : 2019-05-22 - 13:56
 */
public class t_2 {

    // Shows a Mat in a CanvasFrame window; closing the window exits the JVM
    public static void showMatImage(Mat mat,String tit){
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        CanvasFrame canvas = new CanvasFrame(tit, 1);
        canvas.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        canvas.showImage(converter.convert(mat));
    }

    // Reads an image as grayscale, failing fast with the path instead of erroring later in threshold()
    private static Mat readGray(String path) {
        Mat img = imread(path, IMREAD_GRAYSCALE);
        if (img.empty()) {
            throw new IllegalArgumentException("Cannot read image: " + path);
        }
        return img;
    }

    // Otsu-binarises an 8-bit single-channel image IN PLACE and returns it flattened into one CV_32F row
    private static Mat toFeatureRow(Mat gray) {
        opencv_imgproc.threshold(gray, gray, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
        Mat row = new Mat();
        gray.reshape(1, 1).row(0).convertTo(row, CV_32F);
        return row;
    }

    // Appends one feature row and one CV_32SC1 label (1 if positive, else 0) per file in dir
    private static void appendSamples(String dir, boolean positive, Mat trainData, Mat labelData) {
        File[] pics = new File(dir).listFiles();
        if (pics == null) {
            throw new IllegalArgumentException("Not a readable directory: " + dir);
        }
        for (File f : pics) {
            trainData.push_back(toFeatureRow(readGray(f.getPath())));
            Mat label = positive
                    ? new Mat().put(Mat.ones(new Size(1, 1), CV_32SC1))
                    : new Mat().put(Mat.zeros(new Size(1, 1), CV_32SC1));
            labelData.push_back(label);
        }
    }

    /**
     * Builds training features and labels and saves them as XML under the key "tag".
     *
     * @param path0    folder of negative samples (label 0, no plate)
     * @param path1    folder of positive samples (label 1, plate)
     * @param trainXml output file for the feature matrix (CV_32F, one row per sample)
     * @param labelXml output file for the label column (CV_32SC1)
     */
    public static void loadTrainData(String path0,String path1,String trainXml,String labelXml){
        Mat trainData = new Mat();
        Mat labelData = new Mat();
        appendSamples(path0, false, trainData, labelData);
        appendSamples(path1, true, trainData, labelData);

        opencv_core.FileStorage ft = new opencv_core.FileStorage(trainXml,WRITE);
        ft.write("tag",trainData);
        opencv_core.FileStorage fl = new opencv_core.FileStorage(labelXml,WRITE);
        fl.write("tag",labelData);
        ft.release();
        fl.release();
    }

    /**
     * Trains a linear C-SVC on the XML data written by {@link #loadTrainData} and saves the model.
     *
     * @param tXml feature XML
     * @param lXml label XML
     * @param path output model path
     */
    public static void trainSvm(String tXml,String lXml,String path){
        SVM svm = SVM.create();
        // C-support vector classification (alternatives: NU_SVC, ONE_CLASS, EPS_SVR, NU_SVR)
        svm.setType(C_SVC);
        // linear kernel (alternatives: POLY, RBF, SIGMOID); javacpp constant, value 0 like org.opencv.ml.SVM.LINEAR
        svm.setKernel(SVM.LINEAR);
        // Left at defaults: degree (POLY), gamma (POLY/RBF/SIGMOID), coef0 (POLY/SIGMOID),
        // nu (NU_SVC/ONE_CLASS/NU_SVR), p (EPS_SVR), C (C_SVC/EPS_SVR/NU_SVR), class weights (C_SVC).
        // Termination: iteration-count criterion, at most 1000 iterations
        svm.setTermCriteria(new TermCriteria(CV_TERMCRIT_ITER,1000,FLT_EPSILON));

        FileStorage ft = new FileStorage(tXml,READ);
        FileStorage fl = new FileStorage(lXml,READ);
        try {
            Mat trainMat = ft.get("tag").mat();
            Mat labelMat = fl.get("tag").mat();
            // ROW_SAMPLE: each row is one sample, labels aligned by row
            TrainData tData = TrainData.create(trainMat,ROW_SAMPLE,labelMat);
            svm.train(tData);
            svm.save(path);
        } finally {
            ft.release();
            fl.release();
        }
    }

    /** Loads the model from {@code mXml} and classifies one image file; returns the predicted label. */
    public static float testSvm(String mXml,String image){
        return testSvm(SVM.load(mXml), image);
    }

    /** Classifies one image file with an already-loaded model; returns the predicted label. */
    public static float testSvm(SVM svm, String image) {
        return svm.predict(toFeatureRow(readGray(image)));
    }

    /**
     * Keeps the candidates the SVM labels as plates (1.0). Each candidate is shown in a debug window.
     * Input Mats are not modified; the returned Mats are the binarised grayscale copies.
     * NOTE(review): candidates presumably must match the training sample size — confirm.
     *
     * @param in BGR candidate regions
     */
    public static MatVector getCarPic(MatVector in){
        SVM svm = SVM.load("./target/svmModulData.xml");
        MatVector out = new MatVector();

        // fetch the array once; MatVector.get() builds a new array on each call
        Mat[] candidates = in.get();
        for (int i = 0; i < candidates.length; i++) {
            Mat gray = new Mat();
            opencv_imgproc.cvtColor(candidates[i], gray, opencv_imgproc.COLOR_BGR2GRAY);
            Mat feature = toFeatureRow(gray); // also binarises gray in place
            showMatImage(gray,"车牌:"+i);
            float res = svm.predict(feature);
            System.out.println("res:"+res);
            if (res == 1.0) {
                out.push_back(gray);
            }
        }
        return out;
    }

    // Predicts dir + i + ".jpg" for i in [from, to]; returns {correct, wrong} against expected
    private static int[] evaluate(SVM svm, String dir, int from, int to, float expected) {
        int correct = 0;
        int wrong = 0;
        for (int i = from; i <= to; i++) {
            if (testSvm(svm, dir + i + ".jpg") == expected) {
                correct++;
            } else {
                wrong++;
            }
        }
        return new int[]{correct, wrong};
    }

    public static void main(String args[]){
        // One-off preparation steps (run before evaluating):
        // loadTrainData("D:\\my_easypr\\trainData\\0","D:\\my_easypr\\trainData\\1","./target/svmTrainData.xml","./target/svmLabelData.xml");
        // trainSvm("./target/svmTrainData.xml","./target/svmLabelData.xml","./target/svmModulData.xml");

        // load the model once instead of once per test image
        SVM svm = SVM.load("./target/svmModulData.xml");

        // positive test set: 1.jpg .. 50.jpg, expected label 1
        int[] pos = evaluate(svm, "D:\\my_easypr\\testData\\1\\", 1, 50, 1.0f);
        System.out.println("正例测试:"+pos[0]+"正确,"+pos[1]+"错误");

        // negative test set: 0.jpg .. 127.jpg, expected label 0
        int[] neg = evaluate(svm, "D:\\my_easypr\\testData\\0\\", 0, 127, 0.0f);
        System.out.println("负例测试:"+neg[0]+"正确,"+neg[1]+"错误");
    }


}

除了常用的图像处理方法,我也试过用 YOLO 算法做车牌检测:使用 DL4J 深度学习框架,训练了 2000 张图片,效果一般,看来 YOLO 还是不太适合这类小目标的检测。效果如下:

训练代码如下:

package dl4j;


import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Point;
import org.bytedeco.javacpp.opencv_core.Scalar;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.CV_8U;
import static org.bytedeco.javacpp.opencv_imgproc.*;

/**
 * @desc : yolo算法目标检测
 * @auth : TYF
 * @data : 2019/6/12
 */
/**
 * YOLO object detection: fine-tunes a pretrained TinyYOLO on a VOC-style dataset and
 * visualizes its predictions on a held-out split.
 *
 * <p>The dataset is read from the working directory: images under {@code JPEGImages/} and
 * matching VOC {@code .xml} annotations under {@code Annotations/}. The trained model is
 * cached as {@code model.zip} in the same directory; when that file exists it is loaded
 * and training is skipped.
 *
 * <p>Author: TYF, 2019/6/12
 */
public class detectionTrain {

    private static final Logger log = LoggerFactory.getLogger(detectionTrain.class);

    /**
     * System property that, when set to {@code true}, runs {@code shutdown -s -t 30}
     * (a Windows host shutdown) after a freshly trained model has been saved. It is off by
     * default so that running the trainer never powers the machine off unexpectedly.
     */
    private static final String SHUTDOWN_PROPERTY = "detectionTrain.shutdownAfterTraining";

    /**
     * Loads or trains the detector, then shows its predictions on the test split.
     *
     * <p>Steps:
     * <ol>
     *   <li>Split the images 80/20 into train/test. Only images whose annotation file exists
     *       are kept.</li>
     *   <li>If {@code model.zip} exists, load it. Otherwise fine-tune TinyYOLO for
     *       {@code nEpochs} epochs and save the result.</li>
     *   <li>Draw the detected boxes for each test image in a {@link CanvasFrame}. The loop
     *       advances on each key press and stops when the window is closed.</li>
     * </ol>
     *
     * @throws Exception on data/model I/O failures, model (de)serialization errors,
     *                   or UI interruption
     */
    public static void train() throws Exception {

        // The working directory is used as the dataset / project root.
        String path = new File("").getCanonicalPath();

        // Network input size and output grid.
        // NOTE(review): the grid appears to be input / 32 (960/32 = 30, 540/32 ~ 16.9 -> 17);
        // confirm against the model's actual downsampling factor.
        int width = 960;
        int height = 540;
        int nChannels = 3;
        int gridWidth = 30;
        int gridHeight = 17;

        // Number of object classes.
        int nClasses = 1;

        // Yolo2 output-layer parameters.
        int nBoxes = 5;
        double lambdaNoObj = 0.5;
        double lambdaCoord = 5.0;
        // One (w, h) prior per box; the row count must equal nBoxes.
        double[][] priorBoxes = { { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } };
        double detectionThreshold = 0.3;

        // Training parameters.
        int batchSize = 2;
        int nEpochs = 50;
        double learningRate = 1e-3;
        double lrMomentum = 0.9;
        int seed = 123;
        Random rng = new Random(seed);

        String dataDir = path;
        File imageDir = new File(path + "/JPEGImages");

        log.info("load data...");
        // Keep only images that have a VOC annotation:
        // .../JPEGImages/x.jpg -> .../Annotations/x.xml
        // Only ".jpg" is rewritten, so images in other formats map to a non-existent path
        // and are filtered out.
        RandomPathFilter pathFilter = new RandomPathFilter(rng) {
            @Override
            protected boolean accept(String name) {
                name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
                try {
                    return new File(new URI(name)).exists();
                } catch (URISyntaxException ex) {
                    throw new RuntimeException(ex);
                }
            }
        };
        InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.8, 0.2);
        InputSplit trainData = data[0];
        InputSplit testData = data[1];

        // Training-set reader.
        ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir));
        recordReaderTrain.initialize(trainData);

        // Test-set reader.
        ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir));
        recordReaderTest.initialize(testData);

        // Iterators with pixel values scaled to [0, 1].
        RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
        train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
        test.setPreProcessor(new ImagePreProcessingScaler(0, 1));

        // Reuse a previously trained model if one was saved; otherwise fine-tune TinyYOLO.
        ComputationGraph model;
        String modelFilename = path + "/model.zip";
        if (new File(modelFilename).exists()) {
            log.info("load model...");
            model = ModelSerializer.restoreComputationGraph(modelFilename);
        } else {
            log.info("create model...");
            // Pretrained TinyYOLO graph used as the starting point.
            ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained();
            INDArray priors = Nd4j.create(priorBoxes);

            // Fine-tuning configuration applied to the transferred graph.
            FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                    .seed(seed)
                    // Optimization algorithm: stochastic gradient descent.
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    // Renormalize gradients per layer (L2) to guard against vanishing/exploding gradients.
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                    .gradientNormalizationThreshold(1.0)
                    // Updater: Nesterov momentum. Only one updater can take effect, so set exactly one.
                    .updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
                    .activation(Activation.IDENTITY)
                    // Memory management: workspaces for both training and inference.
                    .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                    .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                    .build();

            // Transfer learning: replace the final conv vertex "conv2d_9" with a 1x1 conv that
            // outputs nBoxes * (5 + nClasses) channels, then attach a Yolo2 output layer that
            // uses our box priors.
            model = new TransferLearning.GraphBuilder(pretrained)
                    .fineTuneConfiguration(fineTuneConf)
                    .removeVertexKeepConnections("conv2d_9")
                    .addLayer("convolution2d_9", new ConvolutionLayer.Builder(1, 1)
                            .nIn(1024)
                            .nOut(nBoxes * (5 + nClasses))
                            .stride(1, 1)
                            .convolutionMode(ConvolutionMode.Same)
                            .weightInit(WeightInit.UNIFORM)
                            .hasBias(false)
                            .activation(Activation.IDENTITY)
                            .build(), "leaky_re_lu_8")
                    .addLayer("outputs", new Yolo2OutputLayer.Builder()
                            // Spelled "lambba" as in the builder API this code is compiled against.
                            .lambbaNoObj(lambdaNoObj)
                            .lambdaCoord(lambdaCoord)
                            .boundingBoxPriors(priors)
                            .build(), "convolution2d_9")
                    .setOutputs("outputs")
                    .build();

            log.info("Model summary:\n{}", model.summary(InputType.convolutional(height, width, nChannels)));

            log.info("train...");
            model.setListeners(new ScoreIterationListener(1));
            for (int i = 0; i < nEpochs; i++) {
                train.reset();
                while (train.hasNext()) {
                    model.fit(train.next());
                }
                log.info("*** Completed epoch {} ***", i);
            }
            // Persist the trained model (including updater state) for later runs.
            ModelSerializer.writeModel(model, modelFilename, true);

            // Opt-in convenience for unattended runs: shut the host down after saving.
            // Pass the arguments as a list so no shell command string is parsed.
            if (Boolean.getBoolean(SHUTDOWN_PROPERTY)) {
                log.info("{}=true, scheduling system shutdown in 30s", SHUTDOWN_PROPERTY);
                new ProcessBuilder("shutdown", "-s", "-t", "30").inheritIO().start();
            }
        }

        // Visual check: draw the predicted boxes on each test image, one image per key press.
        NativeImageLoader imageLoader = new NativeImageLoader();
        CanvasFrame frame = new CanvasFrame("detectionTrain");
        try {
            OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
            org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout =
                    (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
            List<String> labels = train.getLabels();
            test.setCollectMetaData(true);
            while (test.hasNext() && frame.isVisible()) {
                org.nd4j.linalg.dataset.DataSet ds = test.next();
                RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
                INDArray features = ds.getFeatures();
                INDArray results = model.outputSingle(features);
                List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
                String fileName = new File(metadata.getURI()).getName();
                log.info("{}: {}", fileName, objs);

                // Features were scaled to [0, 1]; multiply by 255 to get 8-bit pixels for display.
                Mat mat = imageLoader.asMat(features);
                Mat convertedMat = new Mat();
                mat.convertTo(convertedMat, CV_8U, 255, 0);
                int w = width;
                int h = height;
                Mat image = new Mat();
                resize(convertedMat, image, new Size(w, h));
                for (DetectedObject obj : objs) {
                    // Box corners are in grid-cell units; scale them to pixel coordinates.
                    double[] xy1 = obj.getTopLeftXY();
                    double[] xy2 = obj.getBottomRightXY();
                    String label = labels.get(obj.getPredictedClass());
                    int x1 = (int) Math.round(w * xy1[0] / gridWidth);
                    int y1 = (int) Math.round(h * xy1[1] / gridHeight);
                    int x2 = (int) Math.round(w * xy2[0] / gridWidth);
                    int y2 = (int) Math.round(h * xy2[1] / gridHeight);
                    rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
                    putText(image, label, new Point(x1 - 80, y2 + 30), FONT_HERSHEY_DUPLEX, 1, Scalar.RED);
                }
                frame.setTitle(fileName + " - detectionTrain");
                frame.setCanvasSize(w, h);
                frame.showImage(converter.convert(image));
                frame.waitKey();
            }
        } finally {
            // Release the window even if inference or drawing fails.
            frame.dispose();
        }
    }

    /** Command-line entry point; runs {@link #train()}. */
    public static void main(String[] args) throws Exception {
        train();
    }

}

  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论
车牌识别算法(License Plate Recognition Algorithm)是一种基于计算机视觉技术、用于自动识别车牌的算法,主要用于从图像或视频中准确、快速地检测和识别出车牌信息。下面简要介绍车牌识别算法的工作原理和应用。 传统的车牌识别算法主要分为三个步骤:车牌定位、字符分割和字符识别。首先,车牌定位阶段通过图像处理技术,找出图像中车牌的位置和边界框。其次,字符分割阶段将车牌中的字符逐个切割出来,以便进一步识别。最后,字符识别阶段采用模式识别或深度学习方法,对切割后的字符进行识别和分类。此外,基于 CRNN 等端到端序列识别模型的方法可以直接对整块车牌图像进行识别,从而省去显式的字符分割步骤。 车牌识别算法主要应用于交通管理、安防监控、智能停车等场景。在交通管理领域,车牌识别算法能够实现自动识别违章车辆、快速查找失窃车辆等功能,提高交通安全和便捷性。在安防监控领域,车牌识别算法可以用于对进出小区、停车场等场所的车辆进行自动识别和记录,提升安全管理水平。在智能停车系统中,车牌识别算法可以实现车辆自动进出停车场、准确计费等功能,提高停车场的使用效率和用户体验。 然而,车牌识别算法仍然面临一些挑战,如复杂背景环境、光照变化、车牌多样性等问题,这些都会对车牌识别的准确性和鲁棒性带来一定影响。因此,针对这些挑战,研究人员不断提出新的算法和技术来改进车牌识别的性能。例如,采用深度学习方法可以提高对复杂背景和光照变化的适应性,使用大规模车牌数据集进行训练可以提高识别准确率和鲁棒性。 总的来说,车牌识别算法是一项具有广泛应用前景的技术,并且随着计算机视觉和人工智能等领域的不断发展,车牌识别算法将进一步提高其准确性和实时性,为各个领域的应用提供更好的支持。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

0x13

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值