1.准备(linux):
(1)sudo apt-get install graphviz
(2)sudo pip install graphviz
(3)sudo pip install pydotplus
2.评判标准(criterion):
可以选基尼系数(gini)或者信息熵(entropy,对应信息增益)
criterion = 'gini'
criterion = 'entropy'
3.代码(对kaggle中NBA球员的位置进行分类(控球后卫、中锋),特征包括2分球与助攻):
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import pydotplus
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
# Plot palettes: the light shades paint the predicted regions of the
# decision surface, the saturated colours mark individual samples.
cBackground = mpl.colors.ListedColormap(
    ['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cPos = mpl.colors.ListedColormap(['g', 'r', 'b'])

# Position code of the group that gets its own scatter plot.
positions = 'SF'
# The two position codes the decision tree is trained to tell apart.
positionsNSF = ('PG', 'C')
# Feature columns, in order: x-axis feature first, y-axis feature second.
sF = ('AST', '2P')
# Load the raw tables from the current working directory.
# NOTE(review): `players` is not referenced again in the visible script.
players = pd.read_csv('Players.csv')
seasionStat = pd.read_csv('Seasons_Stats.csv')
totalPos, totalAST, total2P = seasionStat['Pos'], seasionStat[sF[0]], seasionStat[sF[1]]

# Keep only the rows whose position is one of the two target classes.
# `pos` collects the position names and `analysis` the matching
# [AST, 2P] feature pairs, row for row.
pos, analysis = [], []
# Walking the three columns in lockstep pairs the values by row order, so
# the loop does not depend on the frame's index labels.
for text, ast, twoPoint in zip(totalPos, totalAST, total2P):
    if text in positionsNSF:
        pos.append(text)
        analysis.append([ast, twoPoint])
# Turn the position names into integer class labels.
labels = LabelEncoder().fit_transform(pos)
# train_size=0.7 puts 70% of the samples into the training part; the
# remaining samples come back as the test part.
xTrain, xTest, labelTrain, labelTest = train_test_split(analysis, labels, train_size=0.7)

# Fit a decision tree that splits on Gini impurity, then predict the
# held-out samples.
model = DecisionTreeClassifier(criterion='gini')
model.fit(xTrain, labelTrain)
testHat = model.predict(xTest)

# accuracy_score is declared as (y_true, y_pred), so the held-out labels
# go first. A single pre-formatted string prints the same text whether
# `print` is the Python 2 statement or the Python 3 function; the two
# spaces reproduce the spacing of the original comma-separated print.
print('accuracy:  %s' % accuracy_score(labelTest, testHat))
# Number of sample points along each axis of the prediction grid.
chunks = 50

# Switch to one row per feature: row 0 holds AST, row 1 holds 2P.
analysis = np.array(analysis).T
astValues, pstValues = analysis[0], analysis[1]
astMin, astMax = min(astValues), max(astValues)
pstMin, pstMax = min(pstValues), max(pstValues)

# Evenly spaced grid spanning the observed range of both features.
astAxis = np.linspace(astMin, astMax, chunks)
pstAxis = np.linspace(pstMin, pstMax, chunks)
xGrid, yGrid = np.meshgrid(astAxis, pstAxis)

# Classify every grid node, then fold the predictions back into grid shape.
xyStack = np.stack([xGrid.ravel(), yGrid.ravel()], axis=1)
yHat = model.predict(xyStack).reshape(xGrid.shape)

# Transpose the test samples the same way so each feature is one row.
xTest = np.array(xTest).T

# Background: predicted class per grid cell. Foreground: the test samples,
# coloured by the class the model predicted for them.
plt.pcolormesh(xGrid, yGrid, yHat, cmap=cBackground)
plt.scatter(xTest[0], xTest[1], c=testHat.ravel(), s=40, cmap=cPos)
plt.title('PG and C')
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.grid()
plt.show()
# Collect the [AST, 2P] pair of every row whose position equals
# `positions`; the constant replaces a second hard-coded copy of the code.
analysis = [
    [ast, twoPoint]
    for text, ast, twoPoint in zip(totalPos, totalAST, total2P)
    if text == positions
]

# One row per feature: row 0 holds AST, row 1 holds 2P.
analysis = np.array(analysis).T
astMin, pstMin = min(analysis[0]), min(analysis[1])
astMax, pstMax = max(analysis[0]), max(analysis[1])

# Plain scatter of the selected rows, with the axes clipped to the
# observed value range.
plt.scatter(analysis[0], analysis[1], c='r')
plt.grid()
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.title(positions)
plt.show()
4.结果示意图(测试数据,准确率为88%):
5.这里面也画了小前锋(SF)的数据:
这里可以看出小前锋的数据比较杂乱,这也表明了小前锋比较全能。