Dataset: diabetes_data_upload.csv
The decision tree is built with the tree.DecisionTreeClassifier() class from the scikit-learn (sklearn) library.
By default, it uses the Gini index as the impurity measure.
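To make the impurity measure concrete, here is a minimal pure-Python sketch of the Gini index, G = 1 - Σ_k p_k², where p_k is the proportion of class k at a node. The helper names gini and split_impurity and the label counts are hypothetical illustrations, not part of sklearn. sklearn scores candidate splits in the same spirit, preferring the split whose children have the lowest size-weighted impurity. Passing criterion="entropy" to DecisionTreeClassifier switches to an entropy-based measure instead.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 over the class proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def split_impurity(left, right):
    """Size-weighted Gini impurity of a binary split (lower is better)."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical labels at a parent node and one candidate split of it
parent = ["Positive"] * 6 + ["Negative"] * 4
left = ["Positive"] * 5 + ["Negative"] * 1
right = ["Positive"] * 1 + ["Negative"] * 3

print(round(gini(parent), 4))                # 0.48
print(round(split_impurity(left, right), 4)) # 0.3167
```

Here the split lowers impurity from 0.48 to about 0.317; at each node the tree greedily picks the feature and threshold with the largest such decrease.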
import pandas as pd
from sklearn import tree

df = pd.read_csv("diabetes_data_upload.csv")

# Convert categorical attributes to binary attributes using get_dummies()
X = pd.get_dummies(df.drop(columns="class"))
y = df["class"]  # target label

# Fit a decision tree (Gini index by default) and print its rules as text
dtc = tree.DecisionTreeClassifier().fit(X, y)
print(tree.export_text(dtc, feature_names=X.columns.tolist()))
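To see what get_dummies() does to the feature table, the sketch below applies the same call to a few hypothetical rows. The column names Age, Gender and Polyuria are assumed to resemble the kind of attributes in diabetes_data_upload.csv (a numeric age plus Male/Female and Yes/No columns); the values are made up.

```python
import pandas as pd

# Hypothetical toy rows standing in for the real CSV
df = pd.DataFrame({
    "Age": [40, 58, 41],
    "Gender": ["Male", "Female", "Male"],
    "Polyuria": ["No", "Yes", "No"],
    "class": ["Positive", "Positive", "Negative"],
})

# Same call as in the tutorial: each string column becomes one 0/1 column per
# category, while numeric columns such as Age pass through unchanged
X = pd.get_dummies(df.drop(columns="class"))
y = df["class"]

print(X.columns.tolist())
# ['Age', 'Gender_Female', 'Gender_Male', 'Polyuria_No', 'Polyuria_Yes']
```

For two-valued attributes each pair of columns is redundant; pd.get_dummies(..., drop_first=True) keeps only one column per attribute, which gives a slightly smaller tree input without losing information.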
Visualizing the decision tree
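import pandas as pd
from sklearn import tree
import graphviz  # Python bindings; the Graphviz system package must also be installed

df = pd.read_csv("diabetes_data_upload.csv")
X = pd.get_dummies(df.drop(columns="class"))
y = df["class"]
dtc = tree.DecisionTreeClassifier().fit(X, y)

# Export the fitted tree as DOT; named features and classes make the nodes readable
dot_data = tree.export_graphviz(
    dtc,
    out_file=None,
    feature_names=X.columns.tolist(),
    class_names=dtc.classes_.tolist(),
    filled=True,
)
graph = graphviz.Source(dot_data)
graph.render("diabetes")  # Generates diabetes.pdf (plus the DOT source file "diabetes")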
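If the Graphviz executables are not available, sklearn's tree.plot_tree can draw the same tree with matplotlib. The sketch below uses a small hypothetical in-memory table instead of the CSV so it runs on its own; with the real file, X, y and dtc would be built as in the code above. The output filename diabetes_tree.png is an arbitrary choice.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen, no display needed
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import tree

# Hypothetical stand-in data so the sketch runs without diabetes_data_upload.csv
df = pd.DataFrame({
    "Age": [40, 58, 41, 45, 60, 55, 35, 62],
    "Polyuria": ["No", "Yes", "No", "Yes", "Yes", "No", "No", "Yes"],
    "class": ["Negative", "Positive", "Negative", "Positive",
              "Positive", "Negative", "Negative", "Positive"],
})
X = pd.get_dummies(df.drop(columns="class"))
y = df["class"]
dtc = tree.DecisionTreeClassifier(random_state=0).fit(X, y)

fig, ax = plt.subplots(figsize=(8, 5))
tree.plot_tree(
    dtc,
    feature_names=X.columns.tolist(),   # real column names instead of x[0], x[1], ...
    class_names=dtc.classes_.tolist(),  # class labels in sorted order
    filled=True,                        # colour each node by its majority class
    ax=ax,
)
fig.savefig("diabetes_tree.png", dpi=150, bbox_inches="tight")
plt.close(fig)
```

In this toy table Polyuria alone separates the classes, so the drawn tree has a single split; on the full dataset the tree will be much deeper, and plot_tree's max_depth argument can limit how many levels are drawn.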
Implementing a Decision Tree in Python & Visualization