Decision trees in Python

1. Preparation (Linux):
(1) sudo apt-get install graphviz
(2) sudo pip install graphviz
(3) sudo pip install pydotplus
The script below also needs scikit-learn, pandas, NumPy and matplotlib.
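A quick way to confirm the installation worked, sketched with only the standard library: `shutil.which` looks for the `dot` executable that the apt package provides, and `importlib.util.find_spec` checks whether the two Python packages are importable. `check_graphviz_setup` is a hypothetical helper name, not part of any library.

```python
import importlib.util
import shutil


def check_graphviz_setup():
    """Report whether the pieces installed in step 1 can be found."""
    return {
        # The Graphviz command-line renderer installed by apt-get
        'dot executable (apt graphviz)': shutil.which('dot') is not None,
        # The two Python packages installed with pip
        'graphviz Python package': importlib.util.find_spec('graphviz') is not None,
        'pydotplus': importlib.util.find_spec('pydotplus') is not None,
    }


if __name__ == '__main__':
    for name, ok in check_graphviz_setup().items():
        print(f'{name}: {"found" if ok else "missing"}')
```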

2. Split criterion (criterion):
Either the Gini impurity or the information gain (entropy) can be used:
criterion = 'gini'
criterion = 'entropy'
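To make the two criteria concrete, here is a small sketch that computes both impurity measures for a node and the impurity decrease of a split; with entropy as the measure, that decrease is the information gain. The helper names and the toy labels are illustrative only: scikit-learn computes these quantities internally, and these functions are not its API. Entropy is shown in bits (log base 2).

```python
import numpy as np


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def entropy(labels):
    """Shannon entropy in bits: -sum_k p_k * log2(p_k)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))


def impurity_decrease(parent, left, right, measure):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = len(left) / n * measure(left) + len(right) / n * measure(right)
    return measure(parent) - weighted


if __name__ == '__main__':
    # Hypothetical node: 4 PG (label 0) and 4 C (label 1), split into two children
    parent = [0, 0, 0, 0, 1, 1, 1, 1]
    left, right = [0, 0, 0, 1], [0, 1, 1, 1]
    print(gini(parent), entropy(parent))  # 0.5 1.0
    print('Gini decrease:', impurity_decrease(parent, left, right, gini))
    print('Information gain:', impurity_decrease(parent, left, right, entropy))
```

The tree picks, at each node, the split with the largest decrease of whichever measure `criterion` selects; the two measures usually rank candidate splits similarly.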

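3. Code (classifying NBA players from the Kaggle dataset by position, point guard (PG) vs. center (C), using two features: assists (AST) and 2-point field goals (2P)):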

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import pydotplus

from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder

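# Light colours for the decision regions, saturated colours for the points
cBackground = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cPos = mpl.colors.ListedColormap(['g', 'r', 'b'])
positions = 'SF'             # small forward, plotted on its own at the end
positionsNSF = 'PG', 'C'     # the two classes: point guard and center
sF = 'AST', '2P'             # features: assists and 2-point field goals

players = pd.read_csv('Players.csv')          # loaded but not used below
seasonStat = pd.read_csv('Seasons_Stats.csv')
totalPos, totalAST, total2P = seasonStat['Pos'], seasonStat[sF[0]], seasonStat[sF[1]]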


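# Keep only PG and C rows; skip rows with a missing AST or 2P value
pos, analysis = [], []
for text, ast, twoP in zip(totalPos, totalAST, total2P):
    if text in positionsNSF and pd.notna(ast) and pd.notna(twoP):
        pos.append(text)
        analysis.append([ast, twoP])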


# Encode 'C'/'PG' as 0/1 (LabelEncoder sorts alphabetically) and hold out 30% for testing
labels = LabelEncoder().fit_transform(pos)
xTrain, xTest, labelTrain, labelTest = train_test_split(analysis, labels, train_size=0.7)

model = DecisionTreeClassifier(criterion='gini')
model.fit(xTrain, labelTrain)
testHat = model.predict(xTest)
print('accuracy:', accuracy_score(labelTest, testHat))

# Predict on a chunks x chunks grid covering the feature ranges to draw the decision regions
chunks = 50
analysis = np.array(analysis).T
astMin, pstMin = min(analysis[0]), min(analysis[1])
astMax, pstMax = max(analysis[0]), max(analysis[1])

astAxis, pstAxis = np.linspace(astMin, astMax, chunks), np.linspace(pstMin, pstMax, chunks)
xGrid, yGrid = np.meshgrid(astAxis, pstAxis)
xyStack = np.stack((xGrid.ravel(), yGrid.ravel()), axis=1)  # one (AST, 2P) row per grid point
yHat = model.predict(xyStack).reshape(xGrid.shape)
xTest = np.array(xTest).T

plt.pcolormesh(xGrid, yGrid, yHat, cmap=cBackground, shading='auto')
# Colour test points by their true label so misclassified points stand out against the regions
plt.scatter(xTest[0], xTest[1], c=labelTest, s=40, cmap=cPos)
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.title('PG and C')
plt.grid()
plt.show()

# Small forwards (SF) on the same two features, for comparison
analysis = []
for text, ast, twoP in zip(totalPos, totalAST, total2P):
    if text == positions and pd.notna(ast) and pd.notna(twoP):
        analysis.append([ast, twoP])

analysis = np.array(analysis).T
astMin, pstMin = min(analysis[0]), min(analysis[1])
astMax, pstMax = max(analysis[0]), max(analysis[1])
plt.scatter(analysis[0], analysis[1], c='r')
plt.grid()
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.title('SF')
plt.show()
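Step 1 installs Graphviz and pydotplus, but the script never renders the tree itself. A minimal sketch of how that could be done: `sklearn.tree.export_graphviz` with `out_file=None` returns the tree as a DOT string, and `pydotplus.graph_from_dot_data(...).write_png(...)` renders it. The helpers `tree_to_dot` and `save_tree_png` are hypothetical names, and the demo data is synthetic, not the Kaggle statistics.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def tree_to_dot(model, feature_names, class_names):
    """Return a fitted DecisionTreeClassifier as a Graphviz DOT string."""
    return export_graphviz(model, out_file=None,
                           feature_names=feature_names,
                           class_names=class_names,
                           filled=True, rounded=True)


def save_tree_png(dot, path):
    """Render a DOT string to PNG. Needs pydotplus and the Graphviz `dot` binary."""
    import pydotplus  # imported here so tree_to_dot works without pydotplus
    pydotplus.graph_from_dot_data(dot).write_png(path)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    # Hypothetical [AST, 2P] rows, NOT the Kaggle data: label 0 = C, label 1 = PG
    X = np.vstack([rng.normal([100, 300], 80, size=(50, 2)),
                   rng.normal([400, 150], 80, size=(50, 2))])
    y = np.array([0] * 50 + [1] * 50)
    demo = DecisionTreeClassifier(criterion='gini', max_depth=3).fit(X, y)
    dot = tree_to_dot(demo, ['AST', '2P'], ['C', 'PG'])
    print(dot.splitlines()[0])  # digraph Tree {
```

In the script, after `model.fit(...)`, this could be called as `save_tree_png(tree_to_dot(model, list(sF), ['C', 'PG']), 'tree.png')`. The class names are listed as `['C', 'PG']` because LabelEncoder sorts the labels, so 'C' becomes 0 and 'PG' becomes 1. An unpruned tree on thousands of rows produces a very large image; passing a small `max_depth` keeps it readable.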

4. Resulting plot (test data from one random 70/30 split; accuracy 88%):
[Figure: PG vs. C decision regions with test points; x-axis AST, y-axis 2P]
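Because the 88% comes from a single random 70/30 split, it will shift from run to run. A sketch that uses k-fold cross-validation instead (`sklearn.model_selection.cross_val_score`) and compares the two criteria from step 2. `compare_criteria` is a hypothetical helper, and the demo data is synthetic, not the Kaggle statistics.

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier


def compare_criteria(X, y, criteria=('gini', 'entropy'), cv=5, max_depth=None):
    """Mean and standard deviation of k-fold accuracy for each split criterion."""
    results = {}
    for crit in criteria:
        model = DecisionTreeClassifier(criterion=crit, max_depth=max_depth,
                                       random_state=0)
        scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')
        results[crit] = (scores.mean(), scores.std())
    return results


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    # Hypothetical [AST, 2P] rows, NOT the Kaggle data: label 0 = C, label 1 = PG
    X = np.vstack([rng.normal([100, 300], 80, size=(100, 2)),
                   rng.normal([400, 150], 80, size=(100, 2))])
    y = np.array([0] * 100 + [1] * 100)
    for crit, (mean, std) in compare_criteria(X, y, max_depth=4).items():
        print(f'{crit}: {mean:.3f} +/- {std:.3f}')
```

With the script's data, this could be called as `compare_criteria(np.array(analysis), labels)` right after `labels` is computed, before `analysis` is transposed for plotting. Trying a few `max_depth` values the same way is a simple check for overfitting, since an unrestricted tree grows until its leaves are pure.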
5. The small forward (SF) data is plotted as well:
[Figure: SF players; x-axis AST, y-axis 2P]
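The SF points are spread widely over both assists and 2-point field goals instead of forming a compact cluster, which suggests that small forwards tend to be all-round players.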
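The observation above is visual. One hedged way to put a number on it is to compare, per position, the coefficient of variation (standard deviation divided by mean) of each feature, which puts AST and 2P on a comparable, unit-free scale. `spread_by_position` is a hypothetical helper, and the demo rows are made up for illustration.

```python
import pandas as pd


def spread_by_position(stats, positions, features=('AST', '2P')):
    """Coefficient of variation (std / mean) of each feature, per position."""
    cols = list(features)
    # Keep the requested positions and drop rows with a missing feature value
    subset = stats[stats['Pos'].isin(positions)].dropna(subset=cols)
    grouped = subset.groupby('Pos')[cols]
    return grouped.std() / grouped.mean()


if __name__ == '__main__':
    # Hypothetical rows, NOT real statistics
    demo = pd.DataFrame({'Pos': ['PG', 'PG', 'C', 'C', 'SF', 'SF'],
                         'AST': [500, 450, 80, 100, 150, 350],
                         '2P': [250, 300, 400, 380, 150, 420]})
    print(spread_by_position(demo, ['PG', 'C', 'SF']))
```

On the real data this would be `spread_by_position(pd.read_csv('Seasons_Stats.csv'), ['PG', 'C', 'SF'])`. Per-season totals also grow with minutes played, so every position will show a sizeable spread; treat the comparison as rough rather than as proof of versatility.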
