opencv系列之机器学习(ml)

(版本为3.4.0)

opencv官方api文档:https://docs.opencv.org/

ml模块的svm操作:

python版本

1、生成训练数据

训练文件分别以类标签为文件名,里面存放对应的类文件

def generate_data(self,file_dir):
    train_data= []
    train_labels = []
    if os.path.exists(file_dir):
        file_list = os.listdir(file_dir)
        for fl in file_list:
            class_dir = os.path.join(file_dir,fl)
            if os.path.isdir(class_dir):
                filenames = os.listdir(class_dir)
                for f in filenames: 
                img_name = os.path.join(class_dir,f)
                img = cv2.imread(img_name)
                img = cv2.resize(img,self.resize,interpolation=cv2.INTER_CUBIC)
                new_img = img.reshape((1,self.resize[0]*self.resize[1]*3))
                train_data.append(new_img[0])
                train_labels .append(int(fl))
    return (train_data,train_labels )
    

2、训练

def svmtrain(train_data,train_labels):
        # 创建分类器  
        svm = cv2.ml.SVM_create()  
        svm.setType(cv2.ml.SVM_C_SVC)  # SVM类型  
        svm.setKernel(cv2.ml.SVM_LINEAR) # 使用线性核  
        svm.setC(1.0)  
        train = np.array(train_data,np.float32)
        train_labels = np.array(train_labels,np.int32)
        train_labels = train_labels.reshape((train_labels.size,1))
        # 训练  
        ret = svm.train(train, cv2.ml.ROW_SAMPLE, train_labels) 
        svm.save("svm_data.dat")

3、测试

 def svmtest(model_path,test_file,resize):
        svm = cv2.ml.SVM_load(model_path)
        test_data = []
        img = cv2.imread(test_file)
        img = cv2.res
  • 4
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值