(OpenCV 版本为 3.4.0)
OpenCV 官方 API 文档：https://docs.opencv.org/
使用 ml 模块进行 SVM 分类（Python 版本）：
1、生成训练数据
训练数据按类别存放：每个类别对应一个子目录，目录名为该类别的整数标签，目录中存放该类别的图片文件。
def generate_data(self, file_dir):
    """Load a labelled image dataset as flattened pixel vectors.

    ``file_dir`` should contain one sub-directory per class, named after the
    class's integer label (e.g. ``0/``, ``1/``), each holding that class's
    images. Every image is resized to ``self.resize`` (passed to
    ``cv2.resize`` as ``dsize``) and flattened to a 1-D array of length
    ``self.resize[0] * self.resize[1] * 3``.

    Args:
        file_dir: Root directory of the dataset.

    Returns:
        A ``(train_data, train_labels)`` tuple of equal-length lists: the
        flattened image arrays and their integer labels. Both lists are
        empty if ``file_dir`` does not exist.

    Sub-directories whose name is not an integer, and files OpenCV cannot
    decode, are skipped with a ``UserWarning``.
    """
    import warnings

    train_data = []
    train_labels = []
    if not os.path.exists(file_dir):
        return (train_data, train_labels)

    # os.listdir returns entries in arbitrary order; sort so the sample
    # order (and hence training) is reproducible.
    for class_name in sorted(os.listdir(file_dir)):
        class_dir = os.path.join(file_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # top-level files are not classes
        try:
            label = int(class_name)
        except ValueError:
            # e.g. ".ipynb_checkpoints" or "__MACOSX" next to the class dirs
            warnings.warn("skipping non-label directory: %r" % (class_dir,))
            continue
        for file_name in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, file_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None rather than
                # raising; skip instead of crashing inside cv2.resize.
                warnings.warn("skipping unreadable image: %r" % (img_path,))
                continue
            img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
            # imread's default (colour) mode gives 3 channels, so this has
            # resize[0] * resize[1] * 3 elements.
            train_data.append(img.reshape(-1))
            train_labels.append(label)
    return (train_data, train_labels)
2、训练
def svmtrain(train_data, train_labels, model_path="svm_data.dat"):
    """Train a linear C-SVC on flattened feature vectors and save it.

    Args:
        train_data: Sequence of equal-length 1-D feature vectors, one per
            sample (e.g. the first element returned by ``generate_data``).
        train_labels: Sequence of integer class labels, one per sample.
        model_path: File the trained model is written to via ``svm.save``.

    Returns:
        The trained OpenCV SVM model.

    Raises:
        ValueError: If there are no samples, the samples are not a 2-D
            array of equal-length rows, or sample and label counts differ.
        RuntimeError: If ``svm.train`` reports failure.
    """
    # OpenCV's ml module expects float32 samples (one per row, matching
    # ROW_SAMPLE below) and an int32 column vector of class responses.
    train = np.array(train_data, np.float32)
    labels = np.array(train_labels, np.int32).reshape(-1, 1)
    if train.ndim != 2 or train.shape[0] == 0:
        raise ValueError(
            "train_data must be a non-empty sequence of equal-length feature vectors"
        )
    if labels.shape[0] != train.shape[0]:
        raise ValueError(
            "got %d samples but %d labels" % (train.shape[0], labels.shape[0])
        )

    # Create the classifier
    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)  # SVM type: C-support vector classification
    svm.setKernel(cv2.ml.SVM_LINEAR)  # linear kernel
    svm.setC(1.0)
    # train() reports success through its bool return value.
    if not svm.train(train, cv2.ml.ROW_SAMPLE, labels):
        raise RuntimeError("SVM training failed")
    svm.save(model_path)
    return svm
3、测试
def svmtest(model_path,test_file,resize):
svm = cv2.ml.SVM_load(model_path)
test_data = []
img = cv2.imread(test_file)
img = cv2.res