import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Load Iris and keep only columns 2 onward of the feature matrix
# (plotted below as petal length and petal width).
iris = datasets.load_iris()
X, y = iris.data[:, 2:], iris.target

# Hold out 30% of the samples for testing. Stratifying on y keeps the class
# proportions in both parts; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=1,
    stratify=y,
)

# Gini-impurity decision tree, capped at depth 4, fitted on the training part.
# fit() returns the estimator itself, so the call can be chained.
clf_tree = DecisionTreeClassifier(
    criterion="gini",
    max_depth=4,
    random_state=1,
).fit(X_train, y_train)
from mlxtend.plotting import plot_decision_regions
# Stack the training and test samples back together so the plot shows every
# sample on top of the fitted tree's decision regions.
X_combined = np.vstack((X_train, X_test))
y_combined = np.hstack((y_train, y_test))

# Draw on the axes created here explicitly instead of relying on pyplot's
# implicit "current axes" state, so the plot cannot land on another figure.
fig, ax = plt.subplots(figsize=(7, 7))
plot_decision_regions(X_combined, y_combined, clf=clf_tree, ax=ax)
ax.set_xlabel('petal length [cm]')
ax.set_ylabel('petal width [cm]')
ax.legend(loc='upper left')
fig.tight_layout()
plt.show()
# Render the fitted tree's structure on its own, larger figure.
fig, ax = plt.subplots(figsize=(10, 10))
tree.plot_tree(
    clf_tree,
    # The model was trained on iris.data[:, 2:], so the matching labels are
    # the feature names from index 2 onward. Without them the nodes show
    # opaque "x[0]" / "x[1]" placeholders.
    feature_names=list(iris.feature_names[2:]),
    # Plain list of str, in ascending class-index order, so each node also
    # reports its majority class by name.
    class_names=iris.target_names.tolist(),
    fontsize=10,
    # Target the axes created above rather than the implicit current axes.
    ax=ax,
)
plt.show()