SVM分类器又被称为支持向量机,因为SVM的原理就是在向量空间建立超平面对数据进行分类。具体的原理这里不做过多介绍,直接贴代码,代码的主要作用是对彩色图像进行二分类。
import os
import cv2 as cv
import numpy as np
# 创建空列表,用来存放训练样本数据
train_mat1 = []
for root, dirs, files in os.walk("E:/data/dianzuhan/2021-8-20/data/L/train/5/OK/", -1):
positive_num = len(files)
print(positive_num)
for image in files:
# 灰度图像读取,并存放在一行。
pso_img = cv.imread("E:/data/dianzuhan/2021-8-20/data/L/train/5/OK/" + image)
img = pso_img.swapaxes(1, 2).swapaxes(0, 1)
img = cv.medianBlur(img, 5)
# img = cv.adaptiveThreshold(img, 255, cv.ADAPTIVE_THRESH_MEAN_C, cv.THRESH_BINARY, 25, 5)
c,h,w=img.shape
Vect = np.zeros(w * h*c)
# 先行后列
for k in range(c):
for i in range(h):
for j in range(w):
Vect[k * w * h + w * i + j] = img[k][i][j]
train_mat1.append(Vect)
# 负样本
train_mat2=[]
for root, dirs, files in os.walk("E:/data/dianzuhan/2021-8-20/data/L/train/5/NG/", -1):
negative_num = len(files)
print(negative_num)
for image in files:
neg_img = cv.imread("E:/data/dianzuhan/2021-8-20/data/L/train/5/NG/" + image)
img = neg_img.swapaxes(1, 2).swapaxes(0, 1)
img = cv.medianBlur(img, 5)
# img = cv.adaptiveThreshold(img, 255, cv.ADAPTIVE_THRESH_MEAN_C, cv.THRESH_BINARY,25, 5)
c,h,w=img.shape
Vect = np.zeros(w * h*3)
# 先行后列
for k in range(c):
for i in range(h):
for j in range(w):
Vect[k*w*h+w * i + j] = img[k][i][j]
train_mat2.append(Vect)
#扩充数据
train_mat=train_mat1+train_mat2
# 将样本从列表转换为数组
train_mat = np.array(train_mat, dtype='float32')
labels_num = np.zeros((train_num), np.int32)
print('1',labels_num)
for i in range(train_num):
if i < positive_num:
labels_num[i] = 1
else:
labels_num[i] = 0
# 调用机器学习模块,创建SVM模型分类器
svm = cv.ml.SVM_create()
# SVM类型
svm.setType(cv.ml.SVM_C_SVC)
# 线性核函数
svm.setKernel(cv.ml.SVM_LINEAR)
svm.setC(0.8)
# 开始训练(数据,类型,标签)
svm.train(train_mat, cv.ml.ROW_SAMPLE, labels_num)
svm.save('l5.mat')
测试部分代码
import cv2 as cv
import numpy as np
import os
import shutil
PATH='E:/data/dianzuhan/2021-8-20/data/L/test/3/OK/'
path='E:/data/dianzuhan/2021-8-20/data/L/test/3/CW/OK/'
svm = cv.ml.SVM_load('l3.mat')
a=0
for root, dirs, files in os.walk(PATH, -1):
test_num = len(files)
# print(files)
for image in files:
test_mat = []
test_img = cv.imread(PATH + image)
img = test_img.swapaxes(1, 2).swapaxes(0, 1) #转换图像的shape顺序将hwc转换成chw
img = cv.medianBlur(img, 5)
c,h,w=img.shape
Vect = np.zeros(w * h*c)
# 先行后列
for k in range(3):
for i in range(h):
for j in range(w):
Vect[k * w * h + w * i + j] = img[k][i][j]
test_mat.append(Vect)
test_mat = np.array(test_mat, dtype='float32')
# SVM预测
(P1, P2) = svm.predict(test_mat)
print('12345',P1,P2)
if P2==0:
a+=1
print(image+' is'+' False')
shutil.copyfile(PATH+image,path+image)#将预判错误的照片复制到一个目录中便于查看具体原因。
print(a,test_num)