参考链接:
https://blog.csdn.net/gsww404/article/details/80398584
https://blog.csdn.net/lights_joy/article/details/46779343
https://blog.csdn.net/u011489887/article/details/80084338
C++: https://blog.csdn.net/weixin_41275726/article/details/83051192#%E9%A2%84%E6%B5%8B%C2%A0
https://blog.csdn.net/xiake001/article/details/77509023
代码实现:
# Example 1: svm.SVC -- support-vector classification.
import numpy as np
from sklearn import svm

# Four 2-D training points in two groups: label 1 (negative x) and label 2
# (positive x).
X = np.array([
    [-1, -1],
    [-2, -1],
    [1, 1],
    [2, 1],
])
y = np.array([1, 1, 2, 2])

# Default hyper-parameters; fit() returns the fitted estimator itself.
clt = svm.SVC().fit(X, y)
prediction = clt.predict([[5.8, -1]])
print(prediction)
# Example 2: svm.SVR -- support-vector regression.
from sklearn import svm

# Two training samples with continuous (real-valued) targets.
X = [
    [0, 0],
    [2, 2],
]
y = [0.5, 2.5]

# Default hyper-parameters; fit() returns the fitted estimator itself.
clf = svm.SVR().fit(X, y)
print(clf.predict([[1, 1]]))
应用样例:
from numpy import *
import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
def loadDataSet(fileName, feature_cols=(0, 1, 2), label_col=5):
    """Load whitespace-separated samples from a text file.

    Each non-blank line is one sample. The fields at ``feature_cols`` are
    parsed as floats to form the feature vector, and the field at
    ``label_col`` is parsed as an int class label.

    Args:
        fileName: Path of the text file to read.
        feature_cols: Column indices used as features. Defaults to the
            first three columns; pass ``range(5)`` to use the first five.
        label_col: Column index of the integer label. Defaults to 5.

    Returns:
        A ``(dataMat, labelMat)`` tuple. ``dataMat`` is a list of float
        feature lists; ``labelMat`` is a list of one-element ``[label]``
        lists, i.e. a column-vector layout.

    Raises:
        IndexError: If a line has fewer columns than requested.
        ValueError: If a selected field is not a valid number.
    """
    dataMat = []
    labelMat = []
    with open(fileName) as fr:
        for line in fr:
            # split() with no argument tolerates repeated spaces and tabs;
            # split(' ') would produce empty fields that crash float().
            lineArr = line.split()
            if not lineArr:
                continue  # skip blank lines, e.g. a trailing empty line
            dataMat.append([float(lineArr[col]) for col in feature_cols])
            labelMat.append([int(lineArr[label_col])])
    return dataMat, labelMat
# Load the training set from a text dataset.
# Samples become a float32 array and labels an int32 array (one [label] row
# per sample), the dtypes cv2.ml uses for samples and class responses.
# np.array is used directly instead of the deprecated np.matrix round-trip.
train_data, train_label = loadDataSet('data3/train/feature-shuffle.txt')
train_data = np.array(train_data, dtype=np.float32)
train_label = np.array(train_label, dtype=np.int32)
print(train_label.shape)

# Load the test set the same way.
test_data, test_label = loadDataSet('data3/val2/feature.txt')
test_data = np.array(test_data, dtype=np.float32)
test_label = np.array(test_label, dtype=np.int32)
print(test_label.shape)
# Create the classifier: C-support vector classification, linear kernel.
# NOTE: this rebinds `svm`, shadowing the sklearn `svm` module imported by
# the demos above; no later visible code uses the sklearn module.
svm = cv2.ml.SVM_create()
svm.setType(cv2.ml.SVM_C_SVC)
svm.setKernel(cv2.ml.SVM_LINEAR)
# NOTE(review): C = 1e-5 is a very small misclassification penalty (strong
# regularization); presumably chosen deliberately -- confirm.
svm.setC(1e-5)

# Train with one sample per row. train() reports success as its boolean
# return value, so stop here rather than save/evaluate an untrained model.
ret = svm.train(train_data, cv2.ml.ROW_SAMPLE, train_label)
if not ret:
    raise RuntimeError("SVM training failed")
# NOTE(review): the model is saved under data2/ while the data is read from
# data3/ -- confirm this path is intended.
svm.save('data2/train/hand_class.xml')

# Support vectors of the trained model.
# NOTE(review): for a linear kernel OpenCV may return compressed support
# vectors here (see getUncompressedSupportVectors) -- confirm.
vec = svm.getSupportVectors()
print("最终结果:",vec)
# Evaluate on the test set.
# Alternatively, reload a previously saved model:
# svm = cv2.ml.SVM_load("weight/test3/hand_detect.xml")
# predict() returns (retval, results); results is indexed per test row below.
(ret, res) = svm.predict(test_data)
# print(res)

# Accuracy: fraction of test rows whose predicted label equals the true one.
# Both sides are flattened so the element-wise comparison does not depend on
# either being a column vector (or an np.matrix).
correct = np.asarray(res).ravel() == np.asarray(test_label).ravel()
n = int(np.count_nonzero(correct))
lens = len(test_data)

# To inspect failures, copy the image of each misclassified sample:
# f1 = open("data/val2/0/hand_point.txt", "r")
# lines = f1.readlines()
# save_path = 'data/error/val2/'
# for i in np.flatnonzero(~correct):
#     root = lines[i].split(' ')[23]
#     file_name = root.split('/')[-1]
#     shutil.copy(root, save_path + file_name + str(test_label[i]) + ".jpg")

Accuracy = n/lens
print("准确度为:",Accuracy)