Using sklearn to classify MNIST with logistic regression, a linear SVM, and an SVM with a Gaussian kernel

1. Approach
Using the sklearn library, train a logistic regression model, a linear SVM, and a non-linear SVM with a Gaussian (RBF) kernel on the multi-class MNIST problem. The multi-class handling is delegated to the estimators themselves: SGDClassifier and LinearSVC train one binary classifier per class (one-vs-rest), while SVC with an RBF kernel combines binary classifiers one-vs-one internally. A sketch of making the strategy explicit is shown below.
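If you prefer the one-vs-rest strategy to be explicit rather than implicit, sklearn's OneVsRestClassifier can wrap any binary classifier. A minimal sketch, assuming the same train_set/train_label arrays produced by the code in section 2 (the max_iter value here is illustrative, not tuned):

from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# One binary LinearSVC is trained per digit class; predict() returns the class
# whose binary classifier gives the highest decision score.
ovr_clf = OneVsRestClassifier(LinearSVC(C=1, max_iter=5000))
# ovr_clf.fit(train_set, train_label)
# y_pred = ovr_clf.predict(test_set)
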
2. Code

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml  # fetch_mldata was removed from sklearn; fetch_openml replaces it
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.pipeline import  Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC
from sklearn.svm import SVC

def load_data():
    # fetch_mldata("MNIST original") no longer exists; fetch_openml serves the same 70,000 digits
    mnist = fetch_openml("mnist_784", version=1, as_frame=False)
    return mnist

def get_data(mnist):
    data = mnist["data"]
    label_temp = mnist["target"]
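    # StratifiedShuffleSplit keeps the digit class proportions equal in train and test;
    # with n_splits=10 the loop below overwrites its variables each iteration,
    # so only the last 80/20 split is actually used.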
    ss = StratifiedShuffleSplit(n_splits=10, test_size=0.2, train_size=0.8, random_state=42)
    for train_index,test_index in ss.split(data,label_temp):
        train_set,train_label = data[train_index],label_temp[train_index]
        test_set,test_label = data[test_index],label_temp[test_index]
    # print(train_label_temp.reshape(-1,1))
    # encoder = OneHotEncoder()
    # train_label = encoder.fit_transform(train_label_temp.reshape(-1,1)).toarray()
    # test_label = encoder.fit_transform(test_label_temp.reshape(-1,1)).toarray()
    return train_set,train_label,test_set,test_label

def logistic_classification(train_set, train_label):
    # SGDClassifier with log loss is logistic regression fitted by SGD
    # (sklearn >= 1.1 spells this loss "log_loss"); multi-class is one-vs-rest
    log_clf = Pipeline([("scaler", StandardScaler()), ("logistic_clf", SGDClassifier(loss="log"))])
    log_clf.fit(train_set, train_label)
    return log_clf

def lin_svm_classification(train_set, train_label):
    # LinearSVC with the hinge loss; multi-class is one-vs-rest
    svm = Pipeline([("scaler", StandardScaler()), ("lin_svm_clf", LinearSVC(loss="hinge", C=1))])
    svm.fit(train_set, train_label)
    return svm

def nonlin_svm_classifier(train_set, train_label):
    # SVC with a Gaussian (RBF) kernel; multi-class is handled one-vs-one internally
    svm = Pipeline([("scaler", StandardScaler()), ("nonlin_svm", SVC(kernel="rbf", C=1, gamma=5))])
    svm.fit(train_set, train_label)
    return svm

def accuracy(clf,test_data,test_label):
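    # fraction of test samples whose predicted label matches the true label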
    y_predict = clf.predict(test_data)
    accuracy = sum(y_predict == test_label) / len(test_label)
    return accuracy

if __name__ == "__main__":
    # load MNIST
    mnist = load_data()
    # split into train and test sets
    train_set, train_label, test_set, test_label = get_data(mnist)
    print(train_label)
    # train the logistic regression classifier
    log_clf = logistic_classification(train_set, train_label)
    # logistic regression accuracy
    accuracy_log = accuracy(log_clf,test_set, test_label)
    print("the accuracy of logistic")
    print(accuracy_log)
    # train the linear SVM classifier
    lin_svm_clf = lin_svm_classification(train_set, train_label)
    # linear SVM accuracy
    accuracy_linsvm = accuracy(lin_svm_clf, test_set, test_label)
    print("the accuracy of lin_svm")
    print(accuracy_linsvm)
    # train the non-linear SVM classifier with a Gaussian (RBF) kernel
    non_lin_svm_clf = nonlin_svm_classifier(train_set, train_label)
    # non-linear SVM accuracy
    accuracy_nonlinsvm = accuracy(non_lin_svm_clf, test_set, test_label)
    print("the accuracy of non_lin_svm")
    print(accuracy_nonlinsvm)
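
Beyond a single accuracy number, per-class behaviour is often more informative on MNIST. A small sketch, assuming the fitted log_clf pipeline and the test_set/test_label arrays from the script above, of inspecting the results with sklearn.metrics:

from sklearn.metrics import classification_report, confusion_matrix

# Per-digit precision, recall and F1 for the logistic regression pipeline
y_pred = log_clf.predict(test_set)
print(classification_report(test_label, y_pred))
# Rows are true digits, columns are predicted digits; off-diagonal counts are confusions
print(confusion_matrix(test_label, y_pred))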

3. Results

the accuracy of logistic
0.8947857142857143
/home/sun/my_python3/lib/python3.7/site-packages/sklearn/svm/base.py:929: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
  "the number of iterations.", ConvergenceWarning)
the accuracy of lin_svm
0.91

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)
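
The RBF-kernel SVM never reported a score: exit code 137 means the process received SIGKILL, which on Linux is usually the out-of-memory killer. Kernel SVM training scales roughly quadratically with the number of samples, and the 56,000-image training set is heavy for it (gamma=5 is also very large for 784 standardized features; sklearn's default gamma="scale" is usually a safer starting point). One possible workaround, as a sketch only (the 10,000-sample subset size is an arbitrary choice, not a tuned one), is to fit the RBF SVC on a stratified subsample:

from sklearn.model_selection import train_test_split

# Keep a stratified 10,000-sample subset just for the expensive RBF-kernel SVC
sub_set, _, sub_label, _ = train_test_split(
    train_set, train_label, train_size=10000, stratify=train_label, random_state=42)
non_lin_svm_clf = nonlin_svm_classifier(sub_set, sub_label)
print("the accuracy of non_lin_svm (10k subsample)")
print(accuracy(non_lin_svm_clf, test_set, test_label))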