scikit-learn一般实例之八:多标签分类

`make_multilabel_classification` 按如下随机过程生成每个样本：

• 选取标签的数目：n ~ Poisson(n_labels)
• 重复 n 次，每次选取一个类别 c：c ~ Multinomial(theta)
• 选取文档长度：k ~ Poisson(length)
• 重复 k 次，每次选取一个单词：w ~ Multinomial(theta_c)

# coding:utf-8

import numpy as np
from pylab import *

from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import LabelBinarizer
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA

# Shared font for every Chinese title, axis label and legend in the figure.
# NOTE: the .ttc path is relative, so it is resolved against the current
# working directory at run time.
myfont = matplotlib.font_manager.FontProperties(
    fname="Microsoft-Yahei-UI-Light.ttc")
# Draw minus signs with the ASCII hyphen instead of U+2212 (presumably because
# the CJK font above lacks that glyph — confirm if the font changes).
mpl.rcParams.update({'axes.unicode_minus': False})

def plot_hyperplane(clf, min_x, max_x, linestyle, label):
    """Draw the decision boundary of a fitted linear binary classifier.

    The boundary ``w0 * x + w1 * y + b = 0`` (first row of ``clf.coef_`` and
    first entry of ``clf.intercept_``) is solved for ``y`` and sampled on a
    grid running from ``min_x - 5`` to ``max_x + 5``.

    Parameters
    ----------
    clf : fitted estimator exposing ``coef_`` and ``intercept_``
    min_x, max_x : float
        Horizontal extent of the data being plotted.
    linestyle : str
        matplotlib format string handed to ``plt.plot`` (e.g. ``'k--'``).
    label : str
        Legend label for the line.
    """
    coef = clf.coef_[0]
    slope = -coef[0] / coef[1]
    offset = clf.intercept_[0] / coef[1]
    # Extend the grid past the data on both sides so the line is long enough.
    grid = np.linspace(min_x - 5, max_x + 5)
    plt.plot(grid, slope * grid - offset, linestyle, label=label)

def _padded_limits(lo, hi, frac=.5):
    """Return ``(lo - pad, hi + pad)`` with ``pad = frac * |hi|``.

    For ``hi >= 0`` this equals the original ``(lo - frac*hi, hi + frac*hi)``
    margins. Taking ``|hi|`` keeps the interval from shrinking or inverting
    when every value is negative. When ``hi == 0`` the data span (or ``frac``
    itself for a single point at 0) is used so the axis keeps non-zero width.
    """
    pad = frac * abs(hi)
    if pad == 0:
        pad = frac * (hi - lo) or frac
    return lo - pad, hi + pad


def plot_subfigure(X, Y, subplot, title, transform):
    """Project X to 2-D, fit one-vs-rest linear SVMs and draw one panel.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
    Y : 2-D 0/1 label-indicator array; columns 0 and 1 are drawn as the two
        classes.
    subplot : int
        Panel index in a 2x2 grid. Panel 2 additionally receives the axis
        labels and the legend.
    title : str
        Panel title (rendered with ``myfont``).
    transform : {"pca", "cca"}
        Projection used to reduce X to two components.

    Raises
    ------
    ValueError
        If ``transform`` is neither ``"pca"`` nor ``"cca"``.
    """
    if transform == "pca":
        X = PCA(n_components=2).fit_transform(X)
    elif transform == "cca":
        # CCA is fitted on (X, Y), so the projection depends on the labels.
        X = CCA(n_components=2).fit(X, Y).transform(X)
    else:
        raise ValueError("unknown transform %r; expected 'pca' or 'cca'"
                         % (transform,))

    min_x, max_x = np.min(X[:, 0]), np.max(X[:, 0])
    min_y, max_y = np.min(X[:, 1]), np.max(X[:, 1])

    classif = OneVsRestClassifier(SVC(kernel='linear'))
    classif.fit(X, Y)

    plt.subplot(2, 2, subplot)
    plt.title(title, fontproperties=myfont)

    # 1-D index arrays of the samples whose indicator for class 0 / class 1
    # is non-zero.
    zero_class = np.flatnonzero(Y[:, 0])
    one_class = np.flatnonzero(Y[:, 1])
    plt.scatter(X[:, 0], X[:, 1], s=40, c='gray')
    plt.scatter(X[zero_class, 0], X[zero_class, 1], s=160, edgecolors='b',
                facecolors='none', linewidths=2, label=u'类别-1')
    plt.scatter(X[one_class, 0], X[one_class, 1], s=80, edgecolors='orange',
                facecolors='none', linewidths=2, label=u'类别-2')

    plot_hyperplane(classif.estimators_[0], min_x, max_x, 'k--',
                    u'类别-1的\n边界')
    plot_hyperplane(classif.estimators_[1], min_x, max_x, 'k-.',
                    u'类别-2的\n边界')
    plt.xticks(())
    plt.yticks(())

    # Padding by |max| keeps the view valid even when all values are negative.
    plt.xlim(*_padded_limits(min_x, max_x))
    plt.ylim(*_padded_limits(min_y, max_y))
    if subplot == 2:
        plt.xlabel(u'第一主成分', fontproperties=myfont)
        plt.ylabel(u'第二主成分', fontproperties=myfont)
        plt.legend(loc="upper left", prop=myfont)

plt.figure(figsize=(8, 6))

# Top row: dataset generated with allow_unlabeled=True; bottom row: with
# allow_unlabeled=False. Left column uses CCA, right column uses PCA.
for allow_unlabeled, first_subplot, cca_title, pca_title in (
        (True, 1, u"有无标签样例 + CCA", u"有无标签样例 + PCA"),
        (False, 3, u"没有无标签样例 + CCA", u"没有无标签样例 + PCA"),
):
    X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                          allow_unlabeled=allow_unlabeled,
                                          random_state=1)
    plot_subfigure(X, Y, first_subplot, cca_title, "cca")
    plot_subfigure(X, Y, first_subplot + 1, pca_title, "pca")

plt.subplots_adjust(left=.04, bottom=.02, right=.97, top=.94,
                    wspace=.09, hspace=.2)
# Set the size on a copy of the shared font instead of passing size= next to
# fontproperties=: older matplotlib releases apply fontproperties after size
# and silently reset the title back to the font's default size.
title_font = myfont.copy()
title_font.set_size(20)
plt.suptitle(u"多标签分类", fontproperties=title_font)
plt.show()


02-05 968
02-24 311

01-19 5万+
11-26 5534
02-24 1107
04-23 1146
04-26 370