样本准备
准备三个文件夹 postdata 正样本 negdata 负样本 testdata 测试样本
所有样本 统一处理为灰度,尺寸 140*30
(文末提供一些样本,需要自取)
postdata: 正样本
negdata:负样本
testdata:测试样本,车牌图片命名开头为p,非车牌开头n 用于读取标签计算准确率
代码:
import cv2
import os
import numpy as np
import time
# 所有样本 统一处理为灰度,尺寸 140*30
def imgToVector(img):
h,w=img.shape[:2]
Vect = np.zeros(h * w)
for i in range(w):
for j in range(h):
Vect[h * i + j] = img[j][i]
return Vect
'''
post_path , neg_path: 正负样本存储目录
save_model:模型存储路径
'''
def train(post_path,neg_path,save_model):
t1=time.time()
'''
训练集
'''
# 图片均以灰度形式读取。 将读取的图片二维矩阵转为一维向量存入训练集 trainningMat
# 得到训练集尺寸: (样本数 , 4200) 数据类型:float32
trainningMat = [] # 训练集
pList, nList = os.listdir(post_path), os.listdir(neg_path)
pNum, nNum = len(pList), len(nList)
# 加入正样本
for file in pList:
src = cv2.imread(os.path.join(post_path, file),0)
trainningMat.append(imgToVector(src))
# 加入负样本
for file in nList:
src = cv2.imread(os.path.join(neg_path, file),0)
trainningMat.append(imgToVector(src))
trainningMat = np.array(trainningMat, dtype='float32')
'''
样本标签
'''
# 得到 label列表 尺寸: (pNum+nNum , 1)
# 前pNum个样本(正样本)标签为1,后面的(负样本)标签为-1
Labels = np.zeros((pNum+nNum, 1), np.int32)
Labels[:pNum]=1
Labels[pNum:]=-1
'''
模型训练
'''
svm = cv2.ml.SVM_create() # 创建SVM model
# 属性设置
svm.setType(cv2.ml.SVM_C_SVC)
svm.setKernel(cv2.ml.SVM_LINEAR)
svm.setC(0.01)
# 训练
svm.train(trainningMat, cv2.ml.ROW_SAMPLE, Labels)
svm.save(save_model)
t2 = time.time()
print("train_time",t2-t1) #训练用时
'''
model: 模型存储路径
test_dir: 测试样本路径
'''
def test(model,test_dir):
test_label = [] # 测试样本标签
testMat=[] # 测试集
correct_count = 0 # 正确数
files = os.listdir(test_dir)
test_num=len(files)
for file in files:
src=cv2.imread(os.path.join(test_dir,file),0)
test_label.append(file[0]) # 存入标签
testMat.append(imgToVector(src))
testMat = np.array(testMat, dtype='float32')
svm2 = cv2.ml.SVM_load(model) #加载模型
(par1, par2) = svm2.predict(testMat) #预测
for i in range(test_num):
if (par2[i][0] == 1 and test_label[i] == 'p') or (par2[i][0] == -1 and test_label[i] == 'n'):
correct_count += 1
accuracy = correct_count / test_num
print("accuracy:", accuracy) #准确率
if __name__ == '__main__':
train("postdata","negdata","svm.mat") #训练
test("svm.mat","testdata") #测试
百度了一些车辆图片截了车牌样本:
链接:https://pan.baidu.com/s/1cGwRwy_QsxdEjjg4zGh3cg
提取码:ogtf
给个赞么~?