1、什么是多分类?
针对多类问题的分类中,具体讲有两种,即multiclass classification和multilabel classification。multiclass是指分类任务中包含不止一个类别时,每条数据仅仅对应其中一个类别,不会对应多个类别。multilabel是指分类任务中不止一个分类时,每条数据可能对应不止一个类别标签,例如一条新闻,可以被划分到多个板块。
无论是multiclass,还是multilabel,做分类时都有两种策略,一个是one-vs-the-rest(one-vs-all),一个是one-vs-one。
在one-vs-all策略中,假设有n个类别,那么就会建立n个二项分类器,每个分类器针对其中一个类别和剩余类别进行分类。进行预测时,利用这n个二项分类器进行分类,得到数据属于当前类的概率,选择其中概率最大的一个类别作为最终的预测结果。
在one-vs-one策略中,同样假设有n个类别,则会针对两两类别建立二项分类器,得到k=n*(n-1)/2个分类器。对新数据进行分类时,依次使用这k个分类器进行分类,每次分类相当于一次投票,分类结果是哪个就相当于对哪个类投了一票。在使用全部k个分类器进行分类后,相当于进行了k次投票,选择得票最多的那个类作为最终分类结果。
在scikit-learn框架中,分别有sklearn.multiclass.OneVsRestClassifier和sklearn.multiclass.OneVsOneClassifier完成两种策略,使用过程中要指明使用的二项分类器是什么。另外在进行mutillabel分类时,训练数据的类别标签Y应该是一个矩阵,第[i,j]个元素指明了第j个类别标签是否出现在第i个样本数据中。例如,np.array([[1, 0, 0], [0, 1, 1], [0, 0, 0]]),这样的一条数据,指明针对第一条样本数据,类别标签是第0个类,第二条数据,类别标签是第1,第2个类,第三条数据,没有类别标签。有时训练数据中,类别标签Y可能不是这样的可是,而是类似[[2, 3, 4], [2], [0, 1, 3], [0, 1, 2, 3, 4], [0, 1, 2]]这样的格式,每条数据指明了每条样本数据对应的类标号。这就需要将Y转换成矩阵的形式,sklearn.preprocessing.MultiLabelBinarizer提供了这个功能。
2、构建多个二分类器进行分类
使用的数据集是sklearn自带的iris数据集,该数据集总共有三类。
importnumpy as npimportmatplotlib.pyplot as pltfrom sklearn importsvm,datasetsfrom itertools importcyclefrom sklearn importsvm, datasetsfrom sklearn.metrics importroc_curve, aucfrom sklearn.model_selection importtrain_test_splitfrom sklearn.preprocessing importlabel_binarizefrom sklearn.multiclass importOneVsRestClassifierfrom scipy importinterp#导入鸢尾花数据集
iris =datasets.load_iris()
X= iris.data #X.shape==(150, 4)
y = iris.target #y.shape==(150, )
#二进制化输出
y = label_binarize(y, classes=[0, 1, 2]) #shape==(150, 3)
n_classes = y.shape[1] #n_classes==3
#np.r_是按列连接两个矩阵,就是把两矩阵上下相加,要求列数相等。#np.c_是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等。#添加噪音特征,使问题更困难
random_state =np.random.RandomState(0)
n_samples, n_features= X.shape #n_samples==150, n_features==4
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] #shape==(150, 84)
#打乱数据集并切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
random_state=0)#X_train.shape==(75, 804), X_test.shape==(75, 804), y_train.shape==(75, 3), y_test.shape==(75, 3)
#学习区分某个类与其他的类