sklearn.multioutput 可以处理多输出 (multi-output) 的分类
一个例子就是预测图片每一个像素(标签) 的像素值是多少 (从 0 到 255 的 256 个类别)
Multioutput 估计器有两个:
-
MultiOutputRegressor
: 多输出回归 -
MultiOutputClassifier
: 多输出分类
MultiOutputClassifier:
-
标签 1 - 小于等于 4,4 和 7 之间,大于等于 7 (三类)
-
标签 2 - 数字本身 (十类)
from sklearn.multioutput import MultiOutputClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import numpy as np
digits = load_digits()
X_train, X_test, y_train, y_test= train_test_split( digits['data'], digits['target'], test_size=0.2 )
y_train_1st=y_train.copy()
y_train_1st[y_train<=4]=0
y_train_1st[np.logical_and(y_train>4,y_train<7)]=1
y_train_1st[y_train>=7]=2
y_train_multioutput=np.c_[y_train_1st,y_train]
print(y_train_multioutput)
#用含有 100 棵决策树的随机森林来解决这个多输入分类问题。
Mo=MultiOutputClassifier(RandomForestClassifier(n_estimators=100))
Mo.fit(X_train,y_train_multioutput)
#这个 ndarray 第一列是标签 1 的类别,第二列是标签 2 的类别。
print(Mo.predict(X_test[:5,:]))
测试结果:
F:\开发工具\pythonProject\tools\venv\Scripts\python.exe F:/开发工具/pythonProject/tools/python的sklear学习/sklearn04.py
[[0 3]
[1 5]
[1 5]
...
[2 9]
[2 9]
[2 7]]
[[0 3]
[0 0]
[0 0]
[0 0]
[2 9]]
Process finished with exit code 0