2023-7-19 日更新:
添加了安卓推理测试。使用 onnxruntime 和 ncnn 部署都差不多。
2023-4-17 更新: 在 yolov5 的基础上增加了关键点(key point)检测,输出车牌的 4 个角点;根据这 4 个角点做投影变换后,再使用 crnn 进行字符识别。参考代码:
https://github.com/we0091234/Chinese_license_plate_detection_recognition.git
原理比较简单,效果见上图,主要是投影变换,下面是java推理代码。
import ai.onnxruntime.*;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.List;
/**
* @desc : 车牌检测 + 车牌字符/颜色识别
* @auth : tyf
* @date : 2023-04-26 18:06:14
*/
public class yolov5_car_plate {
// Model 1: environment and session used by doDetect (plate detection)
public static OrtEnvironment env1;
public static OrtSession session1;
// Model 2: environment and session used by doRecect (plate text / color recognition)
public static OrtEnvironment env2;
public static OrtSession session2;
// 记录一个图片的信息
public static class ImageObj{
// Image resized to the model input size; used for inference
Mat src;
// Image at its original size; used for drawing
Mat background;
// Detections kept after filtering. Each row is
// [x1, y1, x2, y2, kx1, ky1, kx2, ky2, kx3, ky3, kx4, ky4, class, classScore]
List<float[]> data;
// Plate images after the perspective transform
List<Mat> platesMat = new ArrayList<>();
// Recognized plate text, one entry per plate
List<String> platesStr = new ArrayList<>();
// Recognized plate color, one entry per plate
List<Color> platesColor = new ArrayList<>();
// Drawing colors: color1 for the bounding box, color2 for the key points
Scalar color1 = new Scalar(0, 0, 255);
Scalar color2 = new Scalar(0, 255, 0);
// Width and height of the warped plate, i.e. the input size of the second model
int plateWidth = 168;
int plateHeight = 48;
// Size the picture is resized to, i.e. the input size of the first model
int picWidth = 640;
int picHeight = 640;
// Plate character classes; index 0 ('#') is skipped when decoding in doRecect
char[] plateChar = new char[]{
'#','京','沪','津','渝','冀','晋','蒙','辽','吉',
'黑','苏','浙','皖','闽','赣','鲁','豫','鄂','湘',
'粤','桂','琼','川','贵','云','藏','陕','甘','青',
'宁','新','学','警','港','澳','挂','使','领','民',
'航','危','0','1','2','3','4','5','6','7',
'8','9','A','B','C','D','E','F','G','H',
'J','K','L','M','N','P','Q','R','S','T',
'U','V','W','X','Y','Z','险','品',
};
// Plate color classes, in order: black, blue, green, white, yellow
Color [] plateScalar = new Color []{
Color.BLACK,
Color.BLUE,
Color.GREEN,
Color.WHITE,
Color.YELLOW
};
// Width / height scale factors (resized size divided by original size)
float wScale;
float hScale;
/**
 * Loads an image from disk, keeps the original for drawing and builds the
 * resized copy that is fed to the detection model.
 *
 * @param img path of the image file to load
 * @throws IllegalArgumentException if the file cannot be decoded as an image
 */
public ImageObj(String img) {
    // Original image, kept at full resolution for drawing.
    this.background = readImg(img);
    // imread signals failure with an empty Mat rather than an exception;
    // stop here with a readable message instead of failing later inside resize.
    if (this.background == null || this.background.empty()) {
        throw new IllegalArgumentException("Cannot read image: " + img);
    }
    // Copy scaled (without padding) to the detection model's input size.
    this.src = resizeWithoutPadding(this.background, this.picWidth, this.picHeight);
    // Scale factors used by drawBox to map model coordinates back to the original image.
    this.wScale = (float) src.width() / (float) background.width();
    this.hScale = (float) src.height() / (float) background.height();
}
/**
 * Filters the raw detector output by confidence and non-maximum suppression
 * and stores the surviving detections in {@code data}.
 *
 * <p>Each input row is read as
 * {@code [x, y, w, h, objScore, kx1, ky1, kx2, ky2, kx3, ky3, kx4, ky4, class1, class2]}.
 * Each stored row is
 * {@code [x1, y1, x2, y2, kx1, ky1, kx2, ky2, kx3, ky3, kx4, ky4, class, classScore]}.
 *
 * @param output raw detector rows, one candidate per row
 */
public void setDataAndFilter(float[][] output){
    final float confThreshold = 0.75f;
    final float nmsThreshold = 0.45f;

    // Confidence filter on the objectness score.
    List<float[]> candidates = new ArrayList<>();
    for (float[] obj : output) {
        float score = obj[4];
        if (!(score >= confThreshold)) {
            continue;
        }
        // Bounding box as corner coordinates, limited to the model input area.
        float[] xyxy = xywh2xyxy(new float[]{obj[0], obj[1], obj[2], obj[3]}, this.picWidth, this.picHeight);
        float class1 = obj[13];
        float class2 = obj[14];
        // Class 1 or 2, and the probability of the winning class.
        float clazz = class1 > class2 ? 1 : 2;
        float clazzScore = class1 > class2 ? class1 : class2;
        candidates.add(new float[]{
            xyxy[0], xyxy[1], xyxy[2], xyxy[3],
            obj[5], obj[6], obj[7], obj[8], obj[9], obj[10], obj[11], obj[12],
            clazz, clazzScore
        });
    }

    // Non-maximum suppression, highest class score first.
    candidates.sort((o1, o2) -> Float.compare(o2[13], o1[13]));
    List<float[]> kept = new ArrayList<>();
    while (!candidates.isEmpty()) {
        // Take the head out explicitly. Relying on it to suppress itself never
        // terminates when its IoU with itself is not above the threshold
        // (degenerate or NaN boxes).
        float[] best = candidates.remove(0);
        kept.add(best);
        float[] bestBox = {best[0], best[1], best[2], best[3]};
        candidates.removeIf(other ->
            calculateIoU(bestBox, new float[]{other[0], other[1], other[2], other[3]}) > nmsThreshold);
    }
    this.data = kept;
}
/**
 * Applies a perspective transform to every detected plate so that its four
 * key points map onto a {@code plateWidth x plateHeight} rectangle, and
 * appends each warped image to {@code platesMat}.
 */
public void transform(){
    // Target rectangle corners in the order the key points are stored:
    // (0,0), (w,0), (w,h), (0,h). It is the same for every plate, so build it once.
    MatOfPoint2f dstQuad = new MatOfPoint2f(
        new Point(0, 0),
        new Point(plateWidth, 0),
        new Point(plateWidth, plateHeight),
        new Point(0, plateHeight));
    try {
        for (float[] n : this.data) {
            // Key points live at n[4..11] as x/y pairs; truncate to whole pixels.
            Point[] corners = new Point[4];
            for (int k = 0; k < 4; k++) {
                corners[k] = new Point((int) n[4 + 2 * k], (int) n[5 + 2 * k]);
            }
            MatOfPoint2f srcQuad = new MatOfPoint2f(corners);
            Mat M = Imgproc.getPerspectiveTransform(srcQuad, dstQuad);
            Mat warped = new Mat();
            try {
                // Key points are in model-input coordinates, so warp the resized image.
                Imgproc.warpPerspective(src, warped, M, new Size(plateWidth, plateHeight));
            } finally {
                // Intermediate native matrices are not needed once the warp is done.
                M.release();
                srcQuad.release();
            }
            platesMat.add(warped);
        }
    } finally {
        dstQuad.release();
    }
}
/**
 * Draws the bounding box and the four key points of every detection onto
 * {@code background}. Stored coordinates are in model-input space and are
 * divided by {@code wScale}/{@code hScale} to get original-image pixels.
 */
public void drawBox(){
    for (float[] n : this.data) {
        // Bounding box.
        Imgproc.rectangle(
            background,
            new Point((int) (n[0] / wScale), (int) (n[1] / hScale)),
            new Point((int) (n[2] / wScale), (int) (n[3] / hScale)),
            color1,
            2);
        // Four key points, stored as x/y pairs at n[4..11].
        for (int k = 0; k < 4; k++) {
            Imgproc.circle(
                background,
                new Point((int) (n[4 + 2 * k] / wScale), (int) (n[5 + 2 * k] / hScale)),
                3, // radius
                color2,
                2);
        }
        // The recognized text is not drawn, so platesStr is deliberately not read here:
        // it can be shorter than data when recognition failed for a plate.
    }
}
}
/**
 * Loads the OpenCV native library and opens the first model.
 *
 * @param weight path of the ONNX model file
 * @throws Exception if the session cannot be created
 */
public static void init1(String weight) throws Exception{
    System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    OrtEnvironment environment = OrtEnvironment.getEnvironment();
    env1 = environment;
    OrtSession.SessionOptions options = new OrtSession.SessionOptions();
    session1 = environment.createSession(weight, options);
}
/**
 * Loads the OpenCV native library and opens the second model.
 *
 * @param weight path of the ONNX model file
 * @throws Exception if the session cannot be created
 */
public static void init2(String weight) throws Exception{
    System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    OrtEnvironment environment = OrtEnvironment.getEnvironment();
    env2 = environment;
    OrtSession.SessionOptions options = new OrtSession.SessionOptions();
    session2 = environment.createSession(weight, options);
}
/**
 * Converts an OpenCV Mat to a BufferedImage by encoding it as JPEG in memory
 * and decoding the bytes with ImageIO.
 *
 * @param mat image to convert
 * @return the decoded image, or {@code null} if encoding or decoding fails
 */
public static BufferedImage mat2BufferedImage(Mat mat){
    MatOfByte encoded = new MatOfByte();
    try {
        // imencode reports failure through its return value.
        if (!Imgcodecs.imencode(".jpg", mat, encoded)) {
            return null;
        }
        return ImageIO.read(new ByteArrayInputStream(encoded.toArray()));
    } catch (Exception e) {
        // Best effort: callers get null when the conversion is not possible.
        e.printStackTrace();
        return null;
    } finally {
        // The encoded buffer is native memory; free it once the bytes are copied out.
        encoded.release();
    }
}
/**
 * Converts a center-format box to corner format and limits it to the image area.
 *
 * @param bbox      {@code [centerX, centerY, width, height]}
 * @param maxWidth  largest allowed x coordinate
 * @param maxHeight largest allowed y coordinate
 * @return {@code [x1, y1, x2, y2]} with every coordinate inside
 *         {@code [0, maxWidth]} / {@code [0, maxHeight]}
 */
public static float[] xywh2xyxy(float[] bbox,float maxWidth,float maxHeight) {
    float halfW = bbox[2] * 0.5f;
    float halfH = bbox[3] * 0.5f;
    // Clamp every corner on both sides. Clamping only x1/y1 from below and x2/y2
    // from above lets a box that lies outside the image come back inverted.
    return new float[]{
        clampToRange(bbox[0] - halfW, maxWidth),
        clampToRange(bbox[1] - halfH, maxHeight),
        clampToRange(bbox[0] + halfW, maxWidth),
        clampToRange(bbox[1] + halfH, maxHeight)};
}

// Limits value to the closed range [0, max].
private static float clampToRange(float value, float max) {
    return Math.max(0f, Math.min(value, max));
}
/**
 * Reads an image file into a Mat using {@code Imgcodecs.imread}.
 *
 * @param path path of the image file
 * @return the Mat returned by {@code imread}
 */
public static Mat readImg(String path){
    return Imgcodecs.imread(path);
}
/**
 * Regroups 3-channel interleaved data into planar order:
 * all values at indices 0, 3, 6, ... first, then 1, 4, 7, ..., then 2, 5, 8, ...
 *
 * @param src interleaved values
 * @return a new array of the same length in planar order
 */
public static float[] whc2cwh(float[] src) {
    final int channelCount = 3;
    float[] planar = new float[src.length];
    int write = 0;
    int channel = 0;
    while (channel < channelCount) {
        // Walk every channelCount-th element starting at this channel's offset.
        int read = channel;
        while (read < src.length) {
            planar[write++] = src[read];
            read += channelCount;
        }
        channel++;
    }
    return planar;
}
/**
 * Builds the input tensor of the detection model from a BGR image:
 * BGR to RGB, scale pixel values by 1/255, regroup into planar order and
 * wrap as a {@code [1, channels, netHeight, netWidth]} tensor.
 *
 * <p>The Mat is converted in place, so pass a copy if the original is still needed.
 *
 * @param dst       image to convert (modified)
 * @param channels  number of channels
 * @param netWidth  tensor width
 * @param netHeight tensor height
 * @return the input tensor; the caller is responsible for closing it
 * @throws IllegalStateException if the tensor cannot be created
 */
public static OnnxTensor transferTensor(Mat dst,int channels,int netWidth,int netHeight){
    // BGR -> RGB
    Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
    // 0-255 -> 0-1 as 32-bit float (convertTo keeps the channel count).
    dst.convertTo(dst, CvType.CV_32F, 1. / 255);
    // Interleaved pixel data: channels * netWidth * netHeight values.
    float[] whc = new float[channels * netWidth * netHeight];
    dst.get(0, 0, whc);
    // Planar layout expected by the tensor shape below.
    float[] chw = whc2cwh(whc);
    try {
        return OnnxTensor.createTensor(env1, FloatBuffer.wrap(chw), new long[]{1, channels, netHeight, netWidth});
    } catch (OrtException e) {
        // Report the failure to the caller instead of exiting the JVM with status 0.
        throw new IllegalStateException("Failed to create input tensor of shape [1," + channels + "," + netHeight + "," + netWidth + "]", e);
    }
}
/**
 * Builds the input tensor of the recognition model from a BGR image:
 * BGR to RGB, scale by 1/255, normalize with mean 0.588 and standard deviation
 * 0.193 on every channel, regroup into planar order and wrap as a
 * {@code [1, channels, netHeight, netWidth]} tensor.
 *
 * <p>The Mat is converted in place, so pass a copy if the original is still needed.
 *
 * @param dst       image to convert (modified)
 * @param channels  number of channels
 * @param netWidth  tensor width
 * @param netHeight tensor height
 * @return the input tensor; the caller is responsible for closing it
 * @throws IllegalStateException if the tensor cannot be created
 */
public static OnnxTensor transferTensor2(Mat dst,int channels,int netWidth,int netHeight){
    // BGR -> RGB
    Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
    double[] meanValue = {0.588, 0.588, 0.588};
    double[] stdValue = {0.193, 0.193, 0.193};
    // To float in 0-1, then (value - mean) / std per channel.
    dst.convertTo(dst, CvType.CV_32FC3, 1.0 / 255.0);
    Core.subtract(dst, new Scalar(meanValue), dst);
    Core.divide(dst, new Scalar(stdValue), dst);
    // Interleaved pixel data: channels * netWidth * netHeight values.
    float[] whc = new float[channels * netWidth * netHeight];
    dst.get(0, 0, whc);
    // Planar layout expected by the tensor shape below.
    float[] chw = whc2cwh(whc);
    try {
        return OnnxTensor.createTensor(env1, FloatBuffer.wrap(chw), new long[]{1, channels, netHeight, netWidth});
    } catch (OrtException e) {
        // Report the failure to the caller instead of exiting the JVM with status 0.
        throw new IllegalStateException("Failed to create input tensor of shape [1," + channels + "," + netHeight + "," + netWidth + "]", e);
    }
}
/**
 * Computes the intersection-over-union of two boxes given as
 * {@code [xmin, ymin, xmax, ymax]}, using the inclusive-pixel convention
 * (side length = max - min + 1).
 *
 * @param box1 first box
 * @param box2 second box
 * @return the IoU, or 0 when the union area is not positive
 */
private static double calculateIoU(float[] box1, float[] box2) {
    double x1 = Math.max(box1[0], box2[0]);
    double y1 = Math.max(box1[1], box2[1]);
    double x2 = Math.min(box1[2], box2[2]);
    double y2 = Math.min(box1[3], box2[3]);
    double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
    double box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
    double box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
    double unionArea = box1Area + box2Area - intersectionArea;
    // Degenerate or inverted boxes can make the union zero, negative or NaN;
    // treat those as "no overlap" instead of returning NaN/Infinity/negative values.
    if (!(unionArea > 0)) {
        return 0;
    }
    return intersectionArea / unionArea;
}
/**
 * Returns the index of the largest value; the first one wins on ties.
 *
 * @param array values to scan; must contain at least one element
 * @return index of the first maximum
 */
public static int getMaxIndex(float[] array) {
    float best = array[0];
    int bestAt = 0;
    int cursor = 1;
    while (cursor < array.length) {
        float candidate = array[cursor];
        // Strictly greater, so an equal later value does not replace the earlier one.
        if (candidate > best) {
            best = candidate;
            bestAt = cursor;
        }
        cursor++;
    }
    return bestAt;
}
/**
 * Returns a copy of {@code src} resized to exactly {@code netWidth x netHeight}
 * (no padding is added), using INTER_AREA interpolation.
 *
 * @param src       image to resize
 * @param netWidth  target width
 * @param netHeight target height
 * @return the resized image as a new Mat
 */
public static Mat resizeWithoutPadding(Mat src, int netWidth, int netHeight) {
    Mat scaled = new Mat();
    Imgproc.resize(src, scaled, new Size(netWidth, netHeight), 0, 0, Imgproc.INTER_AREA);
    return scaled;
}
/**
 * Runs the detection model on the resized image and stores the filtered
 * detections (bounding box plus four key points) in {@code imageObj}.
 *
 * @param imageObj image to process
 * @throws Exception if inference fails
 */
public static void doDetect(ImageObj imageObj) throws Exception{
    // transferTensor converts its input in place, so work on a copy.
    Mat in = imageObj.src.clone();
    float[][] data;
    // The tensor and the result hold native memory; close both once the
    // output has been copied into a Java array.
    try (OnnxTensor tensor = transferTensor(in, 3, imageObj.picWidth, imageObj.picHeight);
         OrtSession.Result res = session1.run(Collections.singletonMap("input", tensor))) {
        // First output is read as float[1][rows][values]; take the single batch entry.
        data = ((float[][][]) res.get(0).getValue())[0];
    } finally {
        in.release();
    }
    // Confidence and IoU filtering.
    imageObj.setDataAndFilter(data);
}
/**
 * Recognizes the text and color of every detected plate.
 *
 * <p>Each plate is first warped to a rectangle by {@code imageObj.transform()},
 * then passed to the second model. Exactly one entry is appended to
 * {@code platesStr} and {@code platesColor} per plate so both stay index-aligned
 * with {@code platesMat}. When recognition of a plate fails, the error is printed
 * and an empty string and {@code Color.GRAY} are stored as placeholders.
 *
 * @param imageObj image whose detections should be recognized
 */
public static void doRecect(ImageObj imageObj){
    // Warp the key points to a rectangle; the target size is the second model's input (168*48).
    imageObj.transform();
    for (Mat plate : imageObj.platesMat) {
        String text = "";
        Color color = Color.GRAY;
        // transferTensor2 converts its input in place, so work on a copy.
        Mat input = plate.clone();
        try (OnnxTensor tensor = transferTensor2(input, 3, plate.width(), plate.height());
             OrtSession.Result res = session2.run(Collections.singletonMap("images", tensor))) {
            // First output: one score row per position.
            float[][] data1 = ((float[][][]) res.get(0).getValue())[0];
            StringBuilder car = new StringBuilder();
            char last = '-';
            for (float[] scores : data1) {
                int maxIndex = getMaxIndex(scores);
                char maxName = imageObj.plateChar[maxIndex];
                // Skip index 0 and collapse a character repeated at consecutive positions.
                if (maxIndex != 0 && maxName != last) {
                    car.append(maxName);
                }
                last = maxName;
            }
            // Second output: one score per color class.
            float[] data2 = ((float[][]) res.get(1).getValue())[0];
            Color decoded = imageObj.plateScalar[getMaxIndex(data2)];
            // Publish only after both outputs decoded successfully.
            text = car.toString();
            color = decoded;
        } catch (Exception e) {
            // Best effort per plate: keep going with the placeholders.
            e.printStackTrace();
        } finally {
            input.release();
        }
        imageObj.platesStr.add(text);
        imageObj.platesColor.add(color);
    }
}
/**
 * Opens a window showing the annotated picture next to a table with one row
 * per plate: warped plate image, recognized text and a color swatch.
 *
 * @param img processed image to display
 */
public static void showJpanel(ImageObj img){
    JFrame frame = new JFrame("Car");
    JPanel root = new JPanel();

    // Left side: the full picture.
    JPanel picturePanel = new JPanel();
    picturePanel.add(new JLabel(new ImageIcon(mat2BufferedImage(img.background))));

    // Right side: header row plus one row per plate.
    JPanel platesPanel = new JPanel(new FlowLayout(FlowLayout.LEFT, 20, 20));
    JPanel table = new JPanel(new GridLayout(img.platesMat.size() + 1, 1, 0, 5));

    JPanel header = new JPanel(new GridLayout(1, 3, 10, 10));
    for (String caption : new String[]{"投影变换", "车牌号", "颜色"}) {
        JLabel cell = new JLabel(caption);
        cell.setHorizontalAlignment(JLabel.CENTER);
        header.add(cell);
    }
    table.add(header);

    for (int row = 0; row < img.platesMat.size(); row++) {
        JPanel line = new JPanel(new GridLayout(1, 3, 10, 10));
        JLabel picture = new JLabel(new ImageIcon(mat2BufferedImage(img.platesMat.get(row))));
        JLabel number = new JLabel(img.platesStr.get(row));
        // A block character tinted with the plate color serves as the swatch.
        JLabel swatch = new JLabel("█");
        swatch.setForeground(img.platesColor.get(row));
        for (JLabel cell : new JLabel[]{picture, number, swatch}) {
            cell.setHorizontalAlignment(JLabel.CENTER);
        }
        line.add(picture);
        line.add(number);
        line.add(swatch);
        table.add(line);
    }

    platesPanel.add(table);
    root.add(picturePanel);
    root.add(platesPanel);
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    frame.getContentPane().add(root);
    frame.pack();
    frame.setVisible(true);
}
/**
 * Demo entry point: loads both models and the sample picture from
 * {@code src/main/resources/deeplearning/yolov5_car_plate} under the working
 * directory, runs detection and recognition, and shows the result.
 *
 * @param args unused
 * @throws Exception if a model cannot be loaded or inference fails
 */
public static void main(String[] args) throws Exception{
    // Build the path with the platform separator so it also resolves outside Windows.
    String resourceDir = String.join(File.separator,
        new File("").getCanonicalPath(),
        "src", "main", "resources", "deeplearning", "yolov5_car_plate");
    // Model initialization: plate detection, plate recognition.
    init1(resourceDir + File.separator + "plate_detect.onnx");
    init2(resourceDir + File.separator + "plate_rec_color.onnx");
    // Source picture.
    ImageObj img = new ImageObj(resourceDir + File.separator + "car.png");
    // Plate area detection.
    doDetect(img);
    // Plate recognition.
    doRecect(img);
    // Draw boxes on the original picture.
    img.drawBox();
    // Show everything in a window.
    showJpanel(img);
}
}
2019-5-22 更新:
一般来说摄像头的头端的车牌识别用opencv纯图像处理可以达到300-400毫秒一张已经是很成熟商用的方案了。
其他的识别方案比如地感+云识别也是500毫秒左右一张。
尝试过用深度学习的方法来做车牌检测+车牌识别,比较慢。
开源的easypr用的svm等比较古老的方法,尝试了一把识别一张在400毫秒左右。直接复用了easypr的训练结果,效果就是这样子:
先用纯图像处理方法把候选车牌区域检测出来,然后第二步判断该区域是否真的包含车牌,本质上是一个二分类问题。javacv代码如下:
package com.ist.EasyPr;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.TermCriteria;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacpp.opencv_ml.TrainData;
import org.bytedeco.javacpp.opencv_ml.SVM;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.opencv.imgproc.Imgproc;
import javax.swing.*;
import java.io.File;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_core.FileStorage.READ;
import static org.bytedeco.javacpp.opencv_core.FileStorage.WRITE;
import static org.bytedeco.javacpp.opencv_imgcodecs.IMREAD_GRAYSCALE;
import static org.bytedeco.javacpp.opencv_imgcodecs.imread;
import static org.bytedeco.javacpp.opencv_imgproc.CV_THRESH_BINARY;
import static org.bytedeco.javacpp.opencv_imgproc.CV_THRESH_OTSU;
import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE;
import static org.bytedeco.javacpp.opencv_ml.SVM.C_SVC;
import static org.bytedeco.javacpp.opencv_ml.SVM.RBF;
import static org.bytedeco.javacv.JavaCV.FLT_EPSILON;
import static org.opencv.ml.SVM.LINEAR;
/**
* @desc : SVM对正负样本分类,得到包含车牌的图片
* @auth : TYF
* @date : 2019-05-22 - 13:56
*/
public class t_2 {
/**
 * Shows a Mat in its own window; closing the window exits the JVM.
 *
 * @param mat image to display
 * @param tit window title
 */
public static void showMatImage(Mat mat,String tit){
    OpenCVFrameConverter.ToMat toFrame = new OpenCVFrameConverter.ToMat();
    CanvasFrame window = new CanvasFrame(tit, 1);
    window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    window.showImage(toFrame.convert(mat));
}
/**
 * Builds the SVM training matrices from two image folders and writes them to XML.
 *
 * <p>Every image is read as grayscale, binarized with Otsu thresholding and
 * flattened to a single float row. Images from {@code path0} get label 0
 * (negative, no plate); images from {@code path1} get label 1 (positive, plate).
 *
 * @param path0    folder with the negative samples
 * @param path1    folder with the positive samples
 * @param trainXml output file for the sample matrix (stored under key "tag")
 * @param labelXml output file for the label matrix (stored under key "tag")
 * @throws IllegalArgumentException if either folder cannot be listed
 */
public static void loadTrainData(String path0,String path1,String trainXml,String labelXml){
    // Training samples, one row per image.
    Mat trainData = new Mat();
    // Labels, one row per image.
    Mat labelData = new Mat();
    // Negative samples first, then positive ones.
    appendSamples(new File(path0), false, trainData, labelData);
    appendSamples(new File(path1), true, trainData, labelData);
    // Save as XML (svm.train is particular about the element types, hence CV_32F / CV_32SC1).
    opencv_core.FileStorage ft = new opencv_core.FileStorage(trainXml,WRITE);
    try {
        ft.write("tag",trainData);
    } finally {
        ft.release();
    }
    opencv_core.FileStorage fl = new opencv_core.FileStorage(labelXml,WRITE);
    try {
        fl.write("tag",labelData);
    } finally {
        fl.release();
    }
}

// Appends one flattened binary row per image in dir to trainData and the matching 0/1 label to labelData.
private static void appendSamples(File dir, boolean positive, Mat trainData, Mat labelData) {
    File[] pics = dir.listFiles();
    // listFiles returns null when dir is not a directory or cannot be read.
    if (pics == null) {
        throw new IllegalArgumentException("Cannot list sample directory: " + dir);
    }
    for (File f : pics) {
        Mat temp = imread(f.getPath(),IMREAD_GRAYSCALE); // grayscale
        opencv_imgproc.threshold(temp, temp, 0, 255, CV_THRESH_OTSU+CV_THRESH_BINARY); // binary
        Mat convertMat = new Mat();
        temp.reshape(1, 1).row(0).convertTo(convertMat, CV_32F); // single float row
        trainData.push_back(convertMat);
        if (positive) {
            labelData.push_back(new Mat().put(Mat.ones(new Size(1,1),CV_32SC1))); // label 1 (plate)
        } else {
            labelData.push_back(new Mat().put(Mat.zeros(new Size(1,1),CV_32SC1))); // label 0 (no plate)
        }
    }
}
/**
 * Trains a linear C-SVC on the matrices written by {@code loadTrainData}
 * and saves the trained model.
 *
 * <p>Only the SVM type, the kernel and the termination criteria are set here;
 * degree, gamma, coef0, nu, p, C and class weights are not touched.
 *
 * @param tXml XML file holding the sample matrix under key "tag"
 * @param lXml XML file holding the label matrix under key "tag"
 * @param path output file for the trained model
 */
public static void trainSvm(String tXml,String lXml,String path){
    SVM svm = SVM.create();
    // SVM type: C-support vector classification.
    svm.setType(C_SVC);
    // Kernel: linear.
    svm.setKernel(LINEAR);
    // Termination: iteration-count criterion, at most 1000 iterations, epsilon FLT_EPSILON.
    TermCriteria ct = new TermCriteria(CV_TERMCRIT_ITER,1000,FLT_EPSILON);
    svm.setTermCriteria(ct);
    // Training data.
    FileStorage ft = new FileStorage(tXml,READ);
    FileStorage fl = new FileStorage(lXml,READ);
    try {
        Mat trainMat = ft.get("tag").mat();
        Mat labelMat = fl.get("tag").mat();
        // ROW_SAMPLE: every row is one sample / one label.
        TrainData tData = TrainData.create(trainMat,ROW_SAMPLE,labelMat);
        svm.train(tData);
    } finally {
        // Close both readers even when reading or training fails.
        ft.release();
        fl.release();
    }
    // Save the result.
    svm.save(path);
}
/**
 * Classifies a single image with a saved SVM model. The image goes through
 * the same preprocessing as the training samples: grayscale, Otsu
 * binarization, flattened to one float row.
 *
 * @param mXml  path of the saved SVM model
 * @param image path of the image to classify
 * @return the value returned by {@code SVM.predict}
 * @throws IllegalArgumentException if the image cannot be read
 */
public static float testSvm(String mXml,String image){
    SVM svm = SVM.load(mXml);
    Mat temp = imread(image,IMREAD_GRAYSCALE); // grayscale
    // imread yields an empty Mat for a missing or unreadable file; fail with a clear message.
    if (temp == null || temp.empty()) {
        throw new IllegalArgumentException("Cannot read image: " + image);
    }
    opencv_imgproc.threshold(temp, temp, 0, 255, CV_THRESH_OTSU+CV_THRESH_BINARY); // binary
    Mat convertMat = new Mat();
    temp.reshape(1, 1).row(0).convertTo(convertMat, CV_32F); // single float row
    return svm.predict(convertMat);
}
/**
 * Keeps the candidate images that the SVM stored at
 * {@code ./target/svmModulData.xml} classifies as 1 (plate).
 *
 * <p>Each candidate is converted to grayscale and binarized in place, shown in
 * a window, and its prediction is printed.
 *
 * @param in candidate images in BGR
 * @return the candidates predicted as 1, in their binarized form
 */
public static MatVector getCarPic(MatVector in){
    SVM svm = SVM.load("./target/svmModulData.xml");
    MatVector out = new MatVector();
    // Fetch the element array once instead of on every loop test and access.
    Mat[] candidates = in.get();
    for(int i=0;i<candidates.length;i++){
        Mat temp = candidates[i];
        opencv_imgproc.cvtColor(temp, temp, Imgproc.COLOR_BGR2GRAY); // grayscale
        opencv_imgproc.threshold(temp, temp, 0, 255, CV_THRESH_OTSU+CV_THRESH_BINARY); // binary
        showMatImage(temp,"车牌:"+i);
        Mat convertMat = new Mat();
        temp.reshape(1, 1).row(0).convertTo(convertMat, CV_32F); // single float row
        float res = svm.predict(convertMat);
        System.out.println("res:"+res);
        // Positive sample.
        if(res==1.0){
            out.push_back(temp);
        }
    }
    return out;
}
/**
 * Evaluates the saved SVM on the test folders: images 1..50 of folder "1"
 * are expected to predict 1.0, images 0..127 of folder "0" are expected to
 * predict 0.0. Prints the number of correct and wrong predictions for each.
 *
 * <p>To regenerate the model first, run these once:
 * <pre>
 * loadTrainData("D:\\my_easypr\\trainData\\0","D:\\my_easypr\\trainData\\1","./target/svmTrainData.xml","./target/svmLabelData.xml");
 * trainSvm("./target/svmTrainData.xml","./target/svmLabelData.xml","./target/svmModulData.xml");
 * </pre>
 */
public static void main(String args[]){
    // Positive samples.
    int positiveTotal = 50;
    int positiveRight = countPredictions("D:\\my_easypr\\testData\\1\\", 1, 50, 1.0f);
    System.out.println("正例测试:"+positiveRight+"正确,"+(positiveTotal - positiveRight)+"错误");
    // Negative samples.
    int negativeTotal = 128;
    int negativeRight = countPredictions("D:\\my_easypr\\testData\\0\\", 0, 127, 0.0f);
    System.out.println("负例测试:"+negativeRight+"正确,"+(negativeTotal - negativeRight)+"错误");
}

// Counts how many of the images first..last (inclusive) in dir predict exactly the expected value.
private static int countPredictions(String dir, int first, int last, float expected) {
    int right = 0;
    int index = first;
    while (index <= last) {
        float predicted = testSvm("./target/svmModulData.xml", dir+index+".jpg");
        if (predicted == expected) {
            right++;
        }
        index++;
    }
    return right;
}
}
除了用常规的图像处理方法实现车牌检测,我也试过用yolo算法实现车牌检测,效果一般——yolo还是不太适合做这一类小目标的检测。使用的是dl4j深度学习框架,训练了2000张图片,效果如下:
训练代码如下:
package dl4j;
import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Point;
import org.bytedeco.javacpp.opencv_core.Scalar;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.CV_8U;
import static org.bytedeco.javacpp.opencv_imgproc.*;
/**
* @desc : yolo算法目标检测
* @auth : TYF
* @data : 2019/6/12
*/
public class detectionTrain {
private static final Logger log = LoggerFactory.getLogger(detectionTrain.class);

/**
 * Fine-tunes the pretrained TinyYOLO network for a single class and then
 * shows its predictions on the test split, one image per key press.
 *
 * <p>Images are read from {@code JPEGImages} under the working directory; an
 * image is used only if a matching {@code Annotations/*.xml} file exists. The
 * data is split 80/20 into training and test sets. If {@code model.zip} already
 * exists in the working directory it is loaded and training is skipped;
 * otherwise the model is trained, saved, and a system shutdown is scheduled.
 *
 * @throws Exception if loading data, training or saving the model fails
 */
public static void train() throws Exception {
    // Project root directory
    String path = new File("").getCanonicalPath();
    // Basic YOLO parameters: input size, channels and output grid size
    int width = 960;
    int height = 540;
    int nChannels = 3;
    int gridWidth = 30;
    int gridHeight = 17;
    // Number of labels
    int nClasses = 1;
    // Output layer parameters
    int nBoxes = 5;
    double lambdaNoObj = 0.5;
    double lambdaCoord = 5.0;
    double[][] priorBoxes = { { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } };
    double detectionThreshold = 0.3;
    // Training parameters
    int batchSize = 2;
    int nEpochs = 50;
    double learningRate = 1e-3;
    double lrMomentum = 0.9;
    int seed = 123;
    Random rng = new Random(seed);
    String dataDir = path;
    File imageDir = new File(path+"/JPEGImages");
    log.info("load data...");
    RandomPathFilter pathFilter = new RandomPathFilter(rng) {
        @Override
        protected boolean accept(String name) {
            // Accept a picture only when the VOC annotation with the same name exists
            name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
            try {
                return new File(new URI(name)).exists();
            } catch (URISyntaxException ex) {
                throw new RuntimeException(ex);
            }
        }
    };
    InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.8, 0.2);
    InputSplit trainData = data[0];
    InputSplit testData = data[1];
    // Training set
    ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,new VocLabelProvider(dataDir));
    recordReaderTrain.initialize(trainData);
    // Test set
    ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,new VocLabelProvider(dataDir));
    recordReaderTest.initialize(testData);
    // Normalization: scale pixel values to 0..1
    RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
    train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
    RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
    test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
    // Reuse a saved model when present, otherwise fine-tune the pretrained one
    ComputationGraph model;
    String modelFilename = path+"/model.zip";
    if (new File(modelFilename).exists()) {
        log.info("load model...");
        model = ModelSerializer.restoreComputationGraph(modelFilename);
    } else {
        log.info("create model...");
        // Pretrained model
        ComputationGraph pretrained = (ComputationGraph)TinyYOLO.builder().build().initPretrained();
        INDArray priors = Nd4j.create(priorBoxes);
        // Fine-tune configuration for the modified network
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration
            .Builder().seed(seed)
            // Optimization algorithm: stochastic gradient descent
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            // Gradient normalization: RenormalizeL2PerLayer
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(1.0)
            // Updater: Nesterovs. An Adam updater used to be set just before this
            // call; it was dropped because this call replaced it anyway.
            .updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
            .activation(Activation.IDENTITY)
            // Memory management: workspaces enabled
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .build();
        // Transfer learning: replace the last convolution and the output layer
        model = new TransferLearning
            .GraphBuilder(pretrained).
            fineTuneConfiguration(fineTuneConf).
            removeVertexKeepConnections("conv2d_9")
            .addLayer("convolution2d_9",new ConvolutionLayer.Builder(1, 1).nIn(1024).nOut(nBoxes * (5 + nClasses)).stride(1, 1).convolutionMode(ConvolutionMode.Same).weightInit(WeightInit.UNIFORM).hasBias(false).activation(Activation.IDENTITY).build(), "leaky_re_lu_8")
            .addLayer("outputs", new Yolo2OutputLayer.Builder().lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord).boundingBoxPriors(priors).build(),"convolution2d_9")
            .setOutputs("outputs")
            .build();
        System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
        log.info("train...");
        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            train.reset();
            while (train.hasNext()) {
                model.fit(train.next());
            }
            log.info("*** Completed epoch {} ***", i);
        }
        // Save the model
        ModelSerializer.writeModel(model, modelFilename, true);
        // NOTE(review): this schedules a shutdown of the machine once training has finished
        // (the flags look like Windows `shutdown` syntax) - confirm this is still wanted.
        Runtime.getRuntime().exec("shutdown -s -t 30");
    }
    // Visualize the detections on the test set
    NativeImageLoader imageLoader = new NativeImageLoader();
    CanvasFrame frame = new CanvasFrame("detectionTrain");
    OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
    org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
    List<String> labels = train.getLabels();
    test.setCollectMetaData(true);
    while (test.hasNext() && frame.isVisible()) {
        org.nd4j.linalg.dataset.DataSet ds = test.next();
        RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
        INDArray features = ds.getFeatures();
        INDArray results = model.outputSingle(features);
        List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
        File file = new File(metadata.getURI());
        log.info(file.getName() + ": " + objs);
        // Features were scaled to 0..1; convert back to 8-bit for display
        Mat mat = imageLoader.asMat(features);
        Mat convertedMat = new Mat();
        mat.convertTo(convertedMat, CV_8U, 255, 0);
        int w = width;
        int h = height;
        Mat image = new Mat();
        resize(convertedMat, image, new Size(w, h));
        for (DetectedObject obj : objs) {
            double[] xy1 = obj.getTopLeftXY();
            double[] xy2 = obj.getBottomRightXY();
            String label = labels.get(obj.getPredictedClass());
            // Box corners are scaled from grid units to pixels
            int x1 = (int) Math.round(w * xy1[0] / gridWidth);
            int y1 = (int) Math.round(h * xy1[1] / gridHeight);
            int x2 = (int) Math.round(w * xy2[0] / gridWidth);
            int y2 = (int) Math.round(h * xy2[1] / gridHeight);
            rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
            putText(image, label, new Point(x1-80, y2+30), FONT_HERSHEY_DUPLEX, 1, Scalar.RED);
        }
        frame.setTitle(new File(metadata.getURI()).getName() + " - detectionTrain");
        frame.setCanvasSize(w, h);
        frame.showImage(converter.convert(image));
        // Wait for a key press before moving to the next picture
        frame.waitKey();
    }
    frame.dispose();
}
/**
 * Entry point: runs {@code train()}.
 *
 * @param args unused
 * @throws Exception whatever {@code train()} throws
 */
public static void main(String[] args) throws Exception {
    detectionTrain.train();
}
}