构建决策树时出现ValueError: Length of feature_names, 4 does not match number of features, 10的解决办法

构建决策树时出现ValueError: Length of feature_names, 4 does not match number of features, 10的解决办法

以下为原代码:

import pandas as pd
from sklearn import tree

ball_data = pd.read_csv('ball.csv')
cat_features = ['Outlook','Temperature','Humidity','Windy']
ball_data_onehot = pd.get_dummies(ball_data,columns=cat_features)
ball_data_onehot['Play'] = ball_data_onehot['Play'].map({'No':0,'Yes':1}).astype(int)
ball_data_tr_in = ball_data_onehot.drop(columns=['Play'])
ball_data_tr_out = ball_data_onehot['Play']

from sklearn.model_selection import train_test_split

X_train,X_test,y_train,y_test = train_test_split(ball_data_tr_in,ball_data_tr_out,test_size=0.3,random_state=1)
#random_state=1表示数据集中的测试集和训练集每次运行都随机生成,random_state=1表示除第一次运行数据集中的训练集和测试集是随机生成的以外,之后每次运行的数据集和测试集都和第一次运行一样
print(X_train)
print(y_train)
clf = tree.DecisionTreeClassifier(criterion='entropy')
model = clf.fit(X_train,y_train)
res = model.predict(X_test)
print(res)    #模型结果输出
print(y_test)    #实际值
print(sum(res==y_test)/len(res))    #准确率

#测试DecisionTreeClassifier()不加参数的情况
clf1 = tree.DecisionTreeClassifier()
clf1 = clf.fit(ball_data_tr_in,ball_data_tr_out)
print('clf1:'+str(clf1))    #打印模型,打印出来是一个函数
print(clf1)

#自己构建一个数据A,代入模型对A进行预测
A = ([[0,1,0,0,0,1,0,1,0,1]])
predict_result = clf.predict(A)
print('预测结果:'+str(predict_result))
if predict_result==0:
    print('湿度正常,有点雨,有雾,有风的时候不去打球')

#构建决策树
import graphviz
dot_data = tree.export_graphviz(model,out_file = None,
                                feature_names=['Outlook','Temperature','Humidity','Windy'],
                                class_names = ['No','Yes'],
                                filled = True,rounded = True,
                                special_characters = True)
graph = graphviz.Source(dot_data)
graph.render('PlayBall')

pycharm运行后报错将错误索引到构建决策树时的tree.export_graphviz中的feature_names这一参数上,以下为代码错误部分:

dot_data = tree.export_graphviz(model,out_file = None,
                                feature_names=['Outlook','Temperature','Humidity','Windy'],
                                class_names = ['No','Yes'],
                                filled = True,rounded = True,
                                special_characters = True)

错误就出在feature_names=[‘Outlook’,‘Temperature’,‘Humidity’,‘Windy’] 这个地方,feature_names这一参数的作用是使决策树图中的各个小块可以显示其对应的特征名(自己是初学者描述不够详细,具体可以去搜其他关于export_graphviz()参数的资料),而这一参数要求参数值必须与被分析的数据集中的特征名(也叫属性名)对应,数量一致且顺序一致,且不要把类标号class也写进去(因为类标号是最终分析的结果)。

以上面原代码为例,一开始读取ball.csv数据文件,具体内容如下:

OutlookTemperatureHumidityWindyPlay
sunnyhothighnoNo
sunnyhothighyesNo
overcasthothighnoYes
rainmildhighnoYes
raincoolnormalnoYes
raincoolnormalyesNo
overcastcoolnormalyesYes
sunnymildhighyesNo
sunnycoolnormalnoYes
rainmildnormalnoYes
sunnymildnormalyesYes
overcastmildhighyesYes
overcasthotnormalnoYes
rainmildhighyesNo

可以发现特征名为Outlook,Temperature,Humidity,Windy(且以此顺序),Play是类标号,在原代码一开始我们定义了一个cat_features = [‘Outlook’,‘Temperature’,‘Humidity’,‘Windy’] 的列表用于存放我们所需要进行的独热编码的列名称(即特征名),ball_data_onehot = pd.get_dummies(ball_data,columns=cat_features) 用于对数据进行独热编码,注意一定要理解独热编码的具体含义,在此处非常重要,具体可以去查其他相关资料;一开始提到错误出现在feature_name这里,因为feature_name要求和数据的特征名一致,但是此时观察上表以及cat_features发现不管是特征名还是特征名的顺序都一致,但是为什么会出现Length of feature_names, 4 does not match number of features, 10(feature_name的个数4与数据中特征名的个数10不符) 这样的错误呢?
在这里插入图片描述
其实稍微了解独热编码后会发现,pd.get_dummies() 对数据进行独热编码后,数据的结构会发生很大的改变,在对上表的数据进行独热编码后数据变为:

Outlook_overcastOutlook_rainOutlook_sunnyTemperature_coolTemperature_hotTemperature_mildHumidity_highHumidity_normalWindy_noWindy_yes
0010101010
0010101001
1000101010
0100011010
0101000110
0101000101
1001000101
0010011001
0011000110
0100010110
0010010101
1000011001
1000100110
0100011001

发现经过独热编码后此时特征名已经变成了:Outlook_overcast,Outlook_rain,Outlook_sunny,Temperature_cool,Temperature_hot,Temperature_mild,Humidity_high,Humidity_normal,Windy_no,Windy_yes,恰好是10个,真相大白,
这才是feature_names的正确赋值,此时将feature_names改为:

feature_names=['Outlook_overcast','Outlook_rain','Outlook_sunny','Temperature_cool','Temperature_hot','Temperature_mild','Humidity_high','Humidity_normal','Windy_no','Windy_yes']

即可解决错误,之所以出现此错误,是因为对独热编码的方式不够了解,自己以后引以为戒,第一次写博客写的不好,希望能对遇到相同问题的朋友有所帮助。

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页