上一节三节讲述了真实数据(csv表格数据)的一个实战操作的总流程,然而这个处理是一个回归模型,即目标是一些连续的值(median_house_value)。当目标是一些有限的离散值得时候(比如数字0-9),就变成了分类问题,下面开始讲述分类问题。
四、分类问题
下面将使用新的具有代表性的数据集MNIST(手写体数字数据集),数据集总共有70000个小图片,每个小图片为一个手写的数字,(数据中0代表白,1代表黑),数据中把28*28个像素拉成一个向量作为特征,写的数字作为label。
1、关于MNIST数据集
Scikit-learn提供了MNIST数据的下载,如果下载不了也可以自行网站上下载。
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
下载完成后可以输入mnist自行查看一下数据的结构,还可以使用matplotlib输出一张图片看看。
下面需要划分训练集和测试集,MNIST数据集已经帮我们划分好(前60000个为训练集,后10000个位测试集)
X, y = mnist["data"], mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
虽然MNIST数据集中已经把训练测试集分好,但是还未打乱(shuffle),所以需要对训练集进行打乱。
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
2、二分类
假设现在分类是否为数字5,则分类两类(是5或不是5),训练一个SGD分类器(该分类器对大规模的数据处理较快)。
#划分数据
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
#训练模型
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
#交叉验证
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
查准率和查全率(Precision and Recall)以及F1指标
与回归问题计算损失函数不同,二分类特有的一种评价指标为查准率和查全率(Precision and Recall)以及F1指标。
Precision就是预测为正类的样本有多少比例的样本是真的正类,TP/(TP+FP);Recall就是所有真正的正类样本有多少比例被预测为正类,TP/(TP+FN)。其中TP为真正类被预测为正类,FP为负类被预测为正类,FN为真正类被预测为负类。Scikit-learn也有对应的函数
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
recall_score(y_train_5, y_train_pred)
由于Precision和Recall有两个数,如果一大一下的话不好比较两个模型的好坏,F1指标就是结合两者,求调和平均