(SVM 的简单实现)
编辑器环境:VS Code、PyCharm
所需要的文件夹:
- 正样本文件夹 positive(存放正样本图片,即需要识别的目标图片)
- 负样本文件夹 negative(存放负样本图片)
- 测试文件夹 test(存放测试图片)
上面的图片是从微信导入的。
训练图片最好先灰度化再二值化,效果会更好。
把正负样本合并到同一组 samples 和 labels 中,再调用 SVM 的 train 方法进行训练。
注意所有样本的尺寸必须一致(代码中会统一缩放为 160×320)。
import os
import cv2
import numpy as np
def _load_folder(folder, size=(160, 320)):
    """Read every image file in *folder* as a flattened feature vector.

    Args:
        folder: directory to scan. A missing directory yields no samples.
        size: (width, height) passed to cv2.resize before flattening.

    Returns:
        A list of 1-D uint8 arrays of length width * height * 3. Entries that
        are not regular files or that cv2.imread cannot decode are skipped.
    """
    features = []
    if not os.path.isdir(folder):
        return features
    # Sort so the training order (and thus the trained model) is reproducible.
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if not os.path.isfile(path):
            continue
        img = cv2.imread(path)
        if img is None:
            # Not a decodable image (e.g. Thumbs.db, .txt): skip it instead
            # of crashing inside cv2.resize.
            continue
        img = cv2.resize(img, size)
        features.append(img.reshape(-1))
    return features


def date_load():
    """Load the positive/negative training images from the working directory.

    Images are read from ``./positive`` (label 1) and ``./negative``
    (label -1). Each is resized to 160x320 (width x height) and flattened
    row-major into 160*320*3 values.

    Returns:
        samples: float32 array of shape (n, 160*320*3), positives first.
        labels: int32 array of shape (n, 1) holding 1 / -1.
    """
    pwd = os.getcwd()
    print(pwd)
    pos = _load_folder(os.path.join(pwd, 'positive'))
    neg = _load_folder(os.path.join(pwd, 'negative'))
    feature_len = 160 * 320 * 3
    # reshape(-1, feature_len) also gives the right (0, feature_len) shape
    # when no images were found.
    samples = np.asarray(pos + neg, dtype=np.float32).reshape(-1, feature_len)
    labels = np.array([1] * len(pos) + [-1] * len(neg), dtype=np.int32).reshape(-1, 1)
    return samples, labels
def train_svm(samples, labels, *, max_iter=5000, epsilon=1e-16, p=0.1):
    """Train a linear epsilon-SVR on the samples and return its weights and bias.

    The +1/-1 labels are fitted by regression (EPS_SVR) with a linear kernel.
    The sign of ``wT @ x + b`` is then used as the class of a flattened
    sample ``x``.

    Args:
        samples: float32 array of shape (n, features), one sample per row.
        labels: array of shape (n, 1) holding 1 / -1.
        max_iter: maximum number of solver iterations (TERM_CRITERIA_MAX_ITER).
        epsilon: required accuracy for stopping (TERM_CRITERIA_EPS).
        p: epsilon of the SVR loss (SVM::setP).

    Returns:
        (wT, b): the support-vector matrix from getSupportVectors() and the
        bias -rho from getDecisionFunction(0).

    Raises:
        ValueError: if *samples* is empty.
    """
    if len(samples) == 0:
        raise ValueError('no training samples; check the positive/negative folders')
    svm = cv2.ml.SVM_create()
    svm.setKernel(cv2.ml.SVM_LINEAR)
    svm.setType(cv2.ml.SVM_EPS_SVR)
    svm.setP(p)
    # Stop at max_iter iterations or once the accuracy reaches epsilon,
    # whichever comes first.
    criteria = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, max_iter, epsilon)
    svm.setTermCriteria(criteria)
    svm.train(samples, cv2.ml.ROW_SAMPLE, labels)
    # For a linear kernel OpenCV keeps a single compressed support vector,
    # which the callers use directly as the weight row.
    wT = svm.getSupportVectors()
    # OpenCV's decision function is sum(alpha_i * K(sv_i, x)) - rho.
    rho, _, _ = svm.getDecisionFunction(0)
    b = -rho
    return wT, b
if __name__ == '__main__':
    samples, labels = date_load()
    wT, b = train_svm(samples, labels)
    # Classify live frames from the default camera; press 'q' to quit.
    # To test on still images from ./test instead, use the image-test loop
    # shown further below in place of this camera loop.
    video = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = video.read()
            if not ret:
                # No frame (camera missing or disconnected): stop cleanly
                # instead of failing on a None frame.
                break
            # Same preprocessing as training: resize to 160x320 (w x h) and
            # flatten row-major into a column vector.
            src = cv2.resize(frame, (160, 320))
            x = src.reshape(-1, 1)
            value = np.dot(wT, x)[0][0] + b
            cv2.imshow('src', src)
            if value > 0:
                text, color = 'BALL', (0, 255, 0)
            else:
                text, color = 'NO BALL', (0, 0, 255)
            cv2.putText(frame, text, (10, 55), cv2.FONT_HERSHEY_DUPLEX, 1, color, 5, cv2.LINE_AA)
            cv2.imshow('img', frame)
            # waitKey returns an int key code (-1 when no key was pressed);
            # mask to the low byte and compare with the code of 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        video.release()
        cv2.destroyAllWindows()
把上面读取摄像头的代码段注释掉,改用下面这段读取图片的代码,即可用 test 文件夹中的图片进行测试:
# Image-based test: classify test/0.jpeg ... test/11.jpeg with the trained
# weights (wT, b) and show each result; press any key for the next image.
for i in range(12):
    path = os.path.join('test', str(i) + '.jpeg')
    img = cv2.imread(path)
    if img is None:
        # Missing or unreadable file: skip it instead of crashing in resize.
        print('cannot read', path)
        continue
    # Same preprocessing as training: 160x320 (w x h), flattened row-major.
    img = cv2.resize(img, (160, 320))
    x = img.reshape(-1, 1)
    value = np.dot(wT, x)[0][0] + b
    print(value)
    if value > 0:
        text, color = 'GOOD', (0, 255, 0)
    else:
        text, color = 'BAD', (0, 0, 255)
    cv2.putText(img, text, (10, 55), cv2.FONT_HERSHEY_DUPLEX, 1, color, 5, cv2.LINE_AA)
    cv2.imshow('img', img)
    cv2.waitKey(0)
其中:
修改下面这行中的 5000(训练的最大迭代次数),即可调整 SVM 求解分离超平面时的迭代上限,进而影响测试样本的判定结果。
# TermCriteria tuple (type flags, max iterations, epsilon): training stops
# after 5000 iterations or once the accuracy reaches 1e-16, whichever is first.
criterial = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_MAX_ITER, 5000, 1e-16)
train_svm() 函数用样本和标签训练 SVM,并返回权重向量 wT 和偏置 b;主程序据此计算每张图片的决策值,并把判断结果打印在图片上。
本程序只对整张图片做分类判断,在此基础上可以扩展为针对具体物体的识别与检测。