# 寻找多数类别
def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    Ties are broken in favor of the label that appears first in the
    list (dict insertion order), matching the original stable
    reverse-sort behavior.
    """
    classCount = {}
    for vote in classList:
        # Count occurrences of each label.
        classCount[vote] = classCount.get(vote, 0) + 1
    # max over an insertion-ordered dict preserves the original tie-break
    # and removes the need for operator.itemgetter.
    return max(classCount, key=classCount.get)
# 选择最好特征
def chooesBeatFeatureToSpilt(dataset):
    """Pick the feature index with the highest information gain (ID3 step).

    dataset: list of examples; the LAST column of each row is the class label.
    Returns the best feature's column index, or -1 if no split yields
    positive information gain.

    NOTE: the misspelled name is kept so existing callers keep working.
    """
    numFeatures = len(dataset[0]) - 1  # last column is the label, not a feature
    baseEntropy = calcShannonEnt(dataset)
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values feature i takes in the dataset.
        uniqueVals = {example[i] for example in dataset}
        # Weighted entropy of the partition induced by splitting on feature i.
        newEntropy = 0
        for val in uniqueVals:
            subDataSet = splitDataSet(dataset, i, val)
            prob = len(subDataSet) / float(len(dataset))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def splitDataSet(dataset, axis, val):
    """Return the examples whose feature at column `axis` equals `val`,
    with that feature column removed.

    The input rows are not mutated; each matching row is copied.
    """
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == val:
            # New list excluding column `axis` (list + list concatenation).
            reducedFeatVec = featVec[:axis] + featVec[axis + 1:]
            retDataSet.append(reducedFeatVec)
    return retDataSet
def calcShannonEnt(dataset):
    """Shannon entropy (base 2) of the class-label column.

    dataset: list of examples; the label is the last element of each row.
    Returns 0.0 for a pure dataset, up to log2(k) for k equally likely labels.
    """
    numExamples = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]  # label is stored in the last column
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / float(numExamples)
        # Entropy accumulates -p * log2(p) per label.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
# 递归生成树节点
def createTree(dataset, labels, featlabels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataset: rows with the class label in the last column.
    labels: feature names aligned with dataset columns; consumed in place
        (the chosen feature's name is deleted), as in the original.
    featlabels: out-parameter collecting feature names in split order.
    Returns either a class label (leaf) or {featureName: {value: subtree}}.
    """
    classList = [example[-1] for example in dataset]  # label column
    # Stop: every example has the same class -> leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: only the label column remains -> majority vote.
    # BUG FIX: the original called majorityCnt() with no argument,
    # which raised TypeError whenever features were exhausted.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooesBeatFeatureToSpilt(dataset)
    bestFeatLabel = labels[bestFeat]
    featlabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]  # mutates the caller's list, as before
    uniqueVals = {example[bestFeat] for example in dataset}
    for value in uniqueVals:
        subLabels = labels[:]  # copy so sibling branches see the same labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataset, bestFeat, value), subLabels, featlabels)
    return myTree
可视化方法一:用 graphviz 导出决策树结构图
from sklearn.tree import export_graphviz
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
# Fit a shallow decision tree on the iris petal features and export the
# fitted tree to a Graphviz .dot file for visualization.
iris = load_iris()
x = iris.data[:, 2:]  # columns 2: are petal length and petal width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(x, y)
# Writes iris_tree.dot; render it with: dot -Tpng iris_tree.dot -o iris_tree.png
export_graphviz(tree_clf,
out_file="iris_tree.dot",
feature_names=iris.feature_names[2:],
class_names=iris.target_names,
rounded=True,
filled=True)
在 cmd 中用 cd 切换到 .dot 文件所在目录,然后输入以下指令生成图片
dot -Tpng iris_tree.dot(目标文件) -o iris_tree.png
边界可视化
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
def plot_decision_boudary(clf, X, y, axes=[0, 7.5, 0, 3], iris=False, legend=False, plt_training=True):
    """Plot clf's predicted decision regions over the box `axes`, plus the
    training points.

    clf: fitted classifier exposing .predict on (n, 2) inputs.
    X, y: training features (n, 2) and integer labels (expects classes 0/1/2).
    axes: [x1_min, x1_max, x2_min, x2_max] plotting bounds.
        NOTE: mutable default is safe here — `axes` is only read, never mutated.
    iris: use iris-specific axis labels and skip the contour outline.
    legend / plt_training: toggle the legend and the scatter of training data.
    """
    # Dense 100x100 grid covering the plotting box.
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    # Filled regions show the predicted class over the grid.
    custom_cmap = ListedColormap(['#FF0000', '#008000', '#0000FF'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if not iris:
        # Outline the region boundaries on non-iris plots.
        custom_cmap2 = ListedColormap(['#FF0000', '#008000', '#0000FF'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    if plt_training:
        plt.plot(X[:, 0][y == 0], X[:, 1][y == 0], "yo", label="1")
        plt.plot(X[:, 0][y == 1], X[:, 1][y == 1], "bs", label="2")
        plt.plot(X[:, 0][y == 2], X[:, 1][y == 2], "g^", label="3")
    plt.axis(axes)
    if iris:
        plt.xlabel("length", fontsize=14)
        plt.ylabel("width", fontsize=14)
    else:
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
    if legend:
        plt.legend(loc="lower right", fontsize=14)
防止决策树生长得太过详细(过拟合)的正则化超参数:
#max_depth=2(最大深度)
#min_samples_split(节点分割前必须具有的最小样本数)
#min_samples_leaf(叶子节点必须具有的最小样本数)
#max_leaf_nodes(叶子节点的最大数量)
#max_features(在每个节点处评估用于拆分的最大特征数)
实战
回归树
from sklearn.tree import export_graphviz
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Fit a regression tree to noisy samples of the quadratic y = 4(x - 0.5)^2
# and export the fitted tree to a Graphviz .dot file.
np.random.seed(42)  # reproducible data
m = 200
X = np.random.rand(m, 1)
y = 4*(X-0.5)**2
y = y+np.random.rand(m, 1)/10  # add small uniform noise
tree_reg = DecisionTreeRegressor(max_depth=4)
tree_reg.fit(X, y)
# Writes reg_tree.dot; render with: dot -Tpng reg_tree.dot -o reg_tree.png
export_graphviz(tree_reg,
out_file="reg_tree.dot",
feature_names=["x1"],
rounded=True,
filled=True)