一、Multiclass classification
打比方你想要把XX归到n_classes
(即彼此不同&互斥的classes)。一共有4个classes
,分明是"Python"、“Java”、"C++“和"Other language”;此刻我们给定有6个XX,各自的classes表示为y
:
import numpy as np
y = np.asarray(['Java', 'C++', 'Other language', 'Python', 'C++', 'Python'])
上述即Multiclass classification
(又叫multimomial classification
)。
为了用模型拟合&用scikit-learn做验证,你需要把文本的class标签转换成数字格式。LableEncoder函数即可
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y_numeric = le.fit_transform(y)
如下是转换后的数字格式的y:
print(y_numeric)
# 结果如下
array([1, 0, 2, 3, 0, 3], dtype=int64)
这些数字的含义是:
print(le.classes_)
# 结果如下
array(['C++', 'Java', 'Other language', 'Python'], dtype='<U14')
当y只有2个label的时候,我们就叫它 binary classfication/二元分类
二、Multilabel classification
这次假如你需要用 n
个binary classifier做这种multiclasses classification(每个分类器负责一个class预测)。此刻0到n_classes - 1
构成的一维数组就无法表示n
个class labels了,我们要用2维数组,每列表示一个label的binary classifier,这里每行表示一个XX(需要注意的是每个classes仍然是不同&互斥的,否则一行里边可能出现多个1
)。用MultiLabelBinarizer来实现:
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
y_indicator = mlb.fit_transform(y[:,None])
1)关于
y[:,None]
: 作用是把1维的y
变成n*1的2维数组。具体原理笔者暂时没有搞明白;
2)此处也可以直接用OneHotEncoder
,结果一模一样;
得到如下的y_indicator:
print(y_indicator)
# 结果如下
array([[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 0, 0, 1]])
每一行的1
即这4个class labels的布尔表示:
print(mlb.classes_)
# 结果如下
array(['C++', 'Java', 'Other language', 'Python'], dtype=object)
三、Multioutput classification(译文有删改)
这里假如你想同时分2个类,例如编程语言
和应用领域
——这就是multioutput classfication
。为了让例子简单易懂,我用的 classes仅3个应用领域
:分别是"Computer Vision"、“Speech Recognition”、“Other application”。我们的y就完全是另一种2维数组了:
y2 = np.asarray([['Java', 'Computer Vision'],
['C++', 'Speech Recognition'],
['Other language', 'Computer Vision'],
['C++', 'Speech Recognition'],
['Python', 'Other application']])
用OneHotEncoder来实现
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder()
y_ohe = ohe.fit_transform(y2)
得到如下的y_ohe:
print(y_ohe.toarray())
# 结果如下
array([[0., 1., 0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0., 0., 1.],
[0., 0., 1., 0., 1., 0., 0.],
[1., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 1., 0., 1., 0.]])
每一行分别用1
表示这2类标签(每一行的前4为表示编程语言,后3位表示编程应用):
print(ohe.categories_)
# 结果如下
[array(['C++', 'Java', 'Other language', 'Python'], dtype='<U18'),
array(['Computer Vision', 'Computer version', 'Speech Recognition'],
dtype='<U18')]