我使用的是python3.4中scikit learn包中的决策树分类器,我希望获得每个输入数据点对应的叶节点id。在
例如,我的输入可能如下所示:array([[ 5.1, 3.5, 1.4, 0.2],
[ 4.9, 3. , 1.4, 0.2],
[ 4.7, 3.2, 1.3, 0.2]])
假设对应的叶节点分别为16、5和45。我希望我的输出是:
^{pr2}$
我已经读完了scikit学习邮件列表和关于SF的相关问题,但是我还是不能让它发挥作用。这是我在邮件列表上找到的一些提示,但仍然不起作用。在
最后,我只想有一个GetLeafNode(clf,X_valida)函数,这样它的输出就是相应叶节点的列表。下面是重现我收到的错误的代码。所以,任何建议都将不胜感激。在from sklearn.datasets import load_iris
from sklearn import tree
# load data and divide it to train and validation
iris = load_iris()
num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]
y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]
# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)
# Now I want to know the corresponding leaf node id for each of my training data point
clf.tree_.apply(X_train)
# This gives the error message below:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in ()
----> 1 clf.tree_.apply(X_train)
_tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()
ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'