python实现决策树数据直接赋值导入_决策树在python中的数据实现

我为python决策树算法实现完成了以下代码:from csv import reader

def load_csv(filename):

file = open(filename, "rb")

lines = reader(file)

dataset = list(lines)

return dataset

# Split a dataset based on an attribute and an attribute value

def test_split(index, value, dataset):

left, right = list(), list()

for row in dataset:

if row[index] < value:

left.append(row)

else:

right.append(row)

return left, right

# Calculate the Gini index for a split dataset

def gini_index(groups, classes):

# count all samples at split point

n_instances = float(sum([len(group) for group in groups]))

# sum weighted Gini index for each group

gini = 0.0

for group in groups:

size = float(len(group))

# avoid divide by zero

if size == 0:

continue

score = 0.0

# score the group based on the score for each class

for class_val in classes:

p = [row[-1] for row in group].count(class_val) / size

score += p * p

# weight the group score by its relative size

gini += (1.0 - score) * (size / n_instances)

return gini

# Select the best split point for a dataset

def get_split(dataset):

class_values = list(set(row[-1] for row in dataset))

b_index, b_value, b_score, b_groups = 999, 999, 999, None

for index in range(len(dataset[0])-1):

for row in dataset:

groups = test_split(index, row[index], dataset)

gini = gini_index(groups, class_values)

if gini < b_score:

b_index, b_value, b_score, b_groups = index, row[index], gini, groups

return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value

def to_terminal(group):

outcomes = [row[-1] for row in group]

return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal

def split(node, max_depth, min_size, depth):

left, right = node['groups']

del(node['groups'])

# check for a no split

if not left or not right:

node['left'] = node['right'] = to_terminal(left + right)

return

# check for max depth

if depth >= max_depth:

node['left'], node['right'] = to_terminal(left), to_terminal(right)

return

# process left child

if len(left) <= min_size:

node['left'] = to_terminal(left)

else:

node['left'] = get_split(left)

split(node['left'], max_depth, min_size, depth+1)

# process right child

if len(right) <= min_size:

node['right'] = to_terminal(right)

else:

node['right'] = get_split(right)

split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree

def build_tree(train, max_depth, min_size):

root = get_split(train)

split(root, max_depth, min_size, 1)

return root

# Print a decision tree

def print_tree(node, depth=0):

if isinstance(node, dict):

print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))

print_tree(node['left'], depth+1)

print_tree(node['right'], depth+1)

else:

print('%s[%s]' % ((depth*' ', node)))

# load and prepare data

filename = 'spine.csv'

dataset = load_csv(filename)

tree = build_tree(dataset, 1, 1)

print_tree(tree)

该数据集共包含13个与脊柱相关的属性,数据挖掘算法用来发现一个人的脊椎是否正常

对于数据集脊椎.csv链接如下:

它显示以下错误:

^{pr2}$

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值