基于python sklearn的 SVM支持向量机 类实现

实现SVM

基于python的sklearn机器学习 类实现

平台
python3.7Anacondasklearn库及配套库

代码:

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from sklearn import svm
from sklearn.externals import joblib#保存模型
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix    # 生成混淆矩阵函数
import matplotlib.pyplot as plt
import matplotlib as mpl
import itertools
class mysvm():
    '''
    调用sklearn 实现SVM功能:
    画混淆矩阵
    输入数据实现训练
    保存模型到指定位置
    调用模型实现预测
    '''
    def plot_confusion_matrix(self,cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,path="maxtix"):
        """
        画混淆矩阵
        This function prints and plots the confusion matrix.
        Normalization can be applied by setting `normalize=True`.
        画图函数 输入:
        cm 矩阵 
        classes 输入str类型
        title 名字
        cmap [图的颜色设置](https://matplotlib.org/examples/color/colormaps_reference.html)
        """
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')
        print(cm)
        plt.figure(figsize=(11,8))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()

        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)
        fmt = '.2f' if normalize else 'd'
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                    horizontalalignment="center",
                      color="white" if cm[i, j] > thresh else "black")
        # plt.gca().set_xticks(tick_marks, minor=True)
        # plt.gca().set_yticks(tick_marks, minor=True)
        # plt.gca().xaxis.set_ticks_position('none')
        # plt.gca().yaxis.set_ticks_position('none')
        #plt.grid()
        # plt.gcf().subplots_adjust(bottom=0.1)
        # plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        #解决中文显示
        plt.rcParams['font.sans-serif']=['SimHei']
        plt.rcParams['axes.unicode_minus'] = False    
        plt.savefig(path,dpi=200)  
        # plt.show()
        
    def justdoSVM(self,x,y,path):
        """
        SVM类
        输入:
        x、y以实现训练,path是保存训练过程的路径
        输出:
        clf 模型
        matrix 混淆矩阵
        dd classifi_report
        kappa kappa系数
        acc_1 模型精度
        """
        depthlist=[]
        depth=np.arange(15,50,15)
        for num in depth:
            print(num)
            X_train,data1x,y_train,data1y = train_test_split(x,y,test_size=0.9,random_state=0)
            #clf=svm.SVC(C=1000000+1000000*num, cache_size=200, class_weight=None, coef0=0.0,
            clf=svm.SVC(C=num, cache_size=200, class_weight=None, coef0=0.0,
            decision_function_shape='ovo', degree=3, gamma=5, kernel='rbf',
            max_iter=-1, probability=False, random_state=None, shrinking=True,
            tol=0.001, verbose=False)
            clf.fit(X_train, y_train)
            y_pred_rf = clf.predict(data1x)
            depthlist.append(accuracy_score(data1y,y_pred_rf))
            print(num)
            print(accuracy_score(data1y, y_pred_rf))  #整体精度
            print(cohen_kappa_score(data1y, y_pred_rf))  #Kappa系数
            print('class预测:\n',classification_report(data1y,y_pred_rf))
            matrix=confusion_matrix(data1y, y_pred_rf)
            kappa=cohen_kappa_score(data1y, y_pred_rf)
            dd=classification_report(data1y, y_pred_rf)
            acc_1=accuracy_score(data1y,y_pred_rf)
            # plt.show()
            #return clf,matrix,dd,kappa
        mpl.rcParams['font.sans-serif'] = ['SimHei']
        plt.figure(facecolor='w')#size
        plt.plot(depth, depthlist, 'ro-', lw=1)
        plt.xlabel('SVM中num参数', fontsize=15)
        plt.ylabel('预测精度', fontsize=15)
        plt.title('SVM数量和过拟合', fontsize=18)
        plt.grid(True)
        plt.savefig(path,dpi=300)
        #plt.show()
        print(depthlist.index(max(depthlist)))
        return clf,matrix,dd,kappa,acc_1
    def save_model(self,clf,src):
        """
        保存模型到某处
        clf 模型
        src 路径
        """
        joblib.dump(clf, src)
    
    def get_model_predit(self,data,src):
        """
        调用模型实现预测
        输入原始数据
        src 模型路径
        返回预测值
        """
        getsavemodel=joblib.load(src)
        predity=getsavemodel.predict(pd.DataFrame(data))
        return predity
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值