1.思路
使用sklearn库,分别用logistic回归、线性SVM和非线性SVM实现对MNIST数据集的多分类。多分类直接采用sklearn内置的、基于二分类器的一对多(one-vs-rest)策略。
2.代码
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC
from sklearn.svm import SVC
def load_data():
    """Download the MNIST dataset and return it as a sklearn Bunch.

    Bug fix: ``fetch_mldata`` was deprecated in scikit-learn 0.20 and
    removed in 0.22, and its backing site (mldata.org) is offline, so
    the original call fails. ``fetch_openml`` is the supported
    replacement and serves the same 70000x784 dataset.
    """
    # Local import: the file-level imports only bring in fetch_mldata.
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml("mnist_784", version=1, cache=True)
    # fetch_openml returns labels as strings ('0'..'9'); cast to floats
    # to match what fetch_mldata("MNIST original") used to return.
    mnist["target"] = mnist["target"].astype(np.float64)
    return mnist
def get_data(mnist):
    """Split MNIST into a stratified 80/20 train/test partition.

    Parameters
    ----------
    mnist : mapping with "data" (feature matrix) and "target" (labels).

    Returns
    -------
    (train_set, train_label, test_set, test_label) numpy arrays.
    """
    data = mnist["data"]
    labels = mnist["target"]
    # Bug fix: the original asked for n_splits=10 and looped over all ten
    # splits, keeping only the last iteration's arrays -- nine splits were
    # computed and silently discarded. We return a single split, so
    # request exactly one.
    ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, train_size=0.8, random_state=42)
    train_index, test_index = next(ss.split(data, labels))
    train_set, train_label = data[train_index], labels[train_index]
    test_set, test_label = data[test_index], labels[test_index]
    return train_set, train_label, test_set, test_label
def logistic_classification(train_set, train_label):
    """Fit a logistic-regression classifier on the training data.

    Uses SGDClassifier with log loss behind a standardization step;
    multi-class handling is one-vs-rest inside sklearn. Returns the
    fitted Pipeline.
    """
    steps = [
        ("scalar", StandardScaler()),
        ("logistic_clf", SGDClassifier(loss="log")),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(train_set, train_label)
    return pipeline
def lin_svm_classification(train_set, train_label, max_iter=10000):
    """Fit a linear SVM (hinge loss, one-vs-rest) on the training data.

    Parameters
    ----------
    train_set, train_label : training features and labels.
    max_iter : liblinear iteration cap. The original run hit
        liblinear's default limit (1000) and emitted a
        ConvergenceWarning (see the captured log), so the default here
        is raised to let the solver converge on MNIST.

    Returns the fitted Pipeline (StandardScaler + LinearSVC).
    """
    svm = Pipeline([
        ("scalar", StandardScaler()),
        ("lin_svm_clf", LinearSVC(loss="hinge", C=1, max_iter=max_iter)),
    ])
    svm.fit(train_set, train_label)
    return svm
def nonlin_svm_classifier(train_set, train_label, C=1, gamma=5):
    """Fit an RBF-kernel SVM (one-vs-rest) on the training data.

    Parameters
    ----------
    train_set, train_label : training features and labels.
    C, gamma : SVC hyperparameters, exposed as keyword arguments
        (defaults keep the original hard-coded values, so existing
        callers are unaffected).

    Returns the fitted Pipeline (StandardScaler + SVC).

    NOTE(review): gamma=5 is very large for 784 standardized features
    (typical choices are near 1/n_features, e.g. gamma="scale"), and
    kernel SVC training is quadratic in sample count -- the captured
    run was killed with SIGKILL (exit 137), presumably out-of-memory,
    on the full 56k-sample training set. Consider a smaller gamma
    and/or subsampling; confirm before relying on the defaults.
    """
    svm = Pipeline([
        ("scalar", StandardScaler()),
        ("nonlin_svm", SVC(kernel="rbf", C=C, gamma=gamma)),
    ])
    svm.fit(train_set, train_label)
    return svm
def accuracy(clf, test_data, test_label):
    """Return the fraction of test samples that *clf* predicts correctly.

    Parameters
    ----------
    clf : fitted estimator exposing ``predict``.
    test_data : feature matrix passed to ``clf.predict``.
    test_label : ground-truth labels, same length as the predictions.

    Returns
    -------
    float in [0, 1].
    """
    y_predict = clf.predict(test_data)
    # np.mean over the boolean match vector replaces sum(...)/len(...);
    # this also avoids the original's local variable that shadowed the
    # function's own name, and returns a plain Python float.
    return float(np.mean(y_predict == test_label))
if __name__ == "__main__":
    # Load the MNIST dataset.
    mnist = load_data()
    # Produce a stratified train/test split.
    train_set, train_label, test_set, test_label = get_data(mnist)
    print(train_label)
    # Train each classifier in turn and report its test-set accuracy.
    # The (tag, trainer) order matches the original script: logistic
    # regression, then linear SVM, then RBF-kernel SVM.
    classifiers = (
        ("logistic", logistic_classification),
        ("lin_svm", lin_svm_classification),
        ("non_lin_svm", nonlin_svm_classifier),
    )
    for tag, train_fn in classifiers:
        model = train_fn(train_set, train_label)
        score = accuracy(model, test_set, test_label)
        print("the accuracy of " + tag)
        print(score)
3.结果
the accuracy of logistic
0.8947857142857143
/home/sun/my_python3/lib/python3.7/site-packages/sklearn/svm/base.py:929: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
"the number of iterations.", ConvergenceWarning)
the accuracy of lin_svm
0.91
Process finished with exit code 137 (interrupted by signal 9: SIGKILL)
(说明:退出码137表示进程被SIGKILL信号强制终止,通常是RBF核SVM在全量MNIST训练集上训练时内存耗尽、被系统的OOM killer杀死所致。)