python(opencv) SVM 测试使用

(SVM简单的编写)
编辑器环境:VsCode , Pycharm
所需要的文件:

  1. 正样本文件件(存放正样本图片-需要train 的图片)
  2. 负样本文件夹(存放负样本的图片)
  3. 测试文件夹(测试图片的存放)
    在这里插入图片描述
    在这里插入图片描述
    上面的图片是从微信导入。
    训练的图片最好使用灰度后的二值化图像,更加好的效果
    将正负样本放到同一个 samples 和labels里面,使用SVM train 去训练图片。
    注意正负样本的大小需要一致。
import os
import cv2
import numpy as np


def date_load():
    pwd = os.getcwd()
    print(pwd)

    pos_dir = os.path.join(pwd,'positive')
    if os.path.exists(pos_dir):
        pos = os.listdir(pos_dir)

    neg_dir = os.path.join(pwd, 'negative')
    if os.path.exists(neg_dir):
        neg = os.listdir(neg_dir)

    samples = []
    labels = []

    # 处理正样品
    for f in pos:
        file_path = os.path.join(pos_dir,f)
        if os.path.exists(file_path):
            pos_img = cv2.imread(file_path)
            pos_img = cv2.resize(pos_img,(160,320))
            descriptors = np.resize(pos_img,(1, 160*320*3))
            samples.append(descriptors)
            labels.append(1.)

    # 处理正样品
    for f in neg:
        file_path = os.path.join(neg_dir, f)
        if os.path.exists(file_path):
            neg_img = cv2.imread(file_path)
            neg_img = cv2.resize(neg_img, (160, 320))
            descriptors = np.resize(neg_img, (1, 160*320*3))
            samples.append(descriptors)
            labels.append(-1.)

    samples_number = len(samples)
    samples = np.float32(samples)
    samples = np.resize(samples, (samples_number, 160*320*3))

    labels = np.int32(labels)
    labels = np.resize(labels, (samples_number, 1))
    return samples, labels



def train_svm(samples, labels):
    svm = cv2.ml.SVM_create()
    svm.setKernel(cv2.ml.SVM_LINEAR)
    svm.setType(cv2.ml.SVM_EPS_SVR)
    svm.setP(0.1)

    criterial = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, 5000, 1e-16)
    svm.setTermCriteria(criterial)
    svm.train(samples, cv2.ml.ROW_SAMPLE, labels)
    wT = svm.getSupportVectors()
    rho,_,_ = svm.getDecisionFunction(0)
    b = -rho
    return wT, b


if __name__ == '__main__':
    samples, labels = date_load()
    wT, b = train_svm(samples, labels)
    img_num = os.path.join(os.getcwd(), 'test')
    video = cv2.VideoCapture(0)

    while True:
        ret, frame = video.read()
        src = frame.copy()
        src = cv2.resize(src, (160, 320))
        x = np.resize(src, (160 * 320*3, 1))
        value = np.dot(wT, x)[0][0] + b
        cv2.imshow('src', src)
        if value > 0:
            cv2.putText(frame, 'BALL', (10, 55), 2, 1, (0, 255, 0), 5, cv2.LINE_AA)
            cv2.imshow('img', frame)
            cv2.waitKey(1)
        else:
            cv2.putText(frame, 'NO BALL', (10, 55), 2, 1, (0, 0, 255), 5, cv2.LINE_AA)
            cv2.imshow('img', frame)
            cv2.waitKey(1)

        if cv2.waitKey(1) == 'q':
            break

    # for i in range(12):
    #     img = cv2.imread('test/' + str(i) + '.jpeg')
    #     img = cv2.resize(img, (160, 320))
    #     x = np.resize(img, (160 * 320*3, 1))
    #     value = np.dot(wT, x)[0][0] + b
    #     print(value)
    #     if value > 0:
    #         cv2.putText(img, 'GOOD', (10, 55), 2, 1, (0, 255, 0), 5, cv2.LINE_AA)
    #         cv2.imshow('img', img)
    #
    #         cv2.waitKey(0)
    #     else:
    #         cv2.putText(img, 'BAD', (10, 55), 2, 1, (0, 0, 255), 5, cv2.LINE_AA)
    #         cv2.imshow('img', img)
    #         cv2.waitKey(0)

将读取摄像头的片段注释打开下面的图片读取就可以实现图片的读取测试

   for i in range(12):
        img = cv2.imread('test/' + str(i) + '.jpeg')
        img = cv2.resize(img, (160, 320))
        x = np.resize(img, (160 * 320*3, 1))
        value = np.dot(wT, x)[0][0] + b
        print(value)
        if value > 0:
            cv2.putText(img, 'GOOD', (10, 55), 2, 1, (0, 255, 0), 5, cv2.LINE_AA)
            cv2.imshow('img', img)

            cv2.waitKey(0)
        else:
            cv2.putText(img, 'BAD', (10, 55), 2, 1, (0, 0, 255), 5, cv2.LINE_AA)
            cv2.imshow('img', img)
            cv2.waitKey(0)

其中:
更改下列的 5000 大小就可以更改后面测试样本的逼近超平面的测试逼近值

    criterial = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, 5000, 1e-16)

在这里插入图片描述
在这里插入图片描述

train_svm()函数为使用 SVM 分类器读取类和标记值。将输出的结果打印在图片上
本程序只针对整张图片的识别检测,在此基础上可以扩展成为具体的物体识别检测。

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值