import csv
from random import shuffle
# Read train.csv and test.csv.
def _read_csv_rows(path, converters):
    """Read *path* as CSV and return its non-blank rows as lists.

    ``converters[k]`` is applied to column k of every row; columns beyond
    ``len(converters)`` are left as strings. Blank lines (``csv.reader``
    yields them as empty lists, e.g. a trailing newline at end of file) are
    skipped instead of causing an IndexError.
    """
    rows = []
    # The csv module documentation says files read by csv.reader should be
    # opened with newline=''.
    with open(path, 'rt', newline='') as f:
        for row in csv.reader(f):
            if not row:
                continue
            for col, convert in enumerate(converters):
                row[col] = convert(row[col])
            rows.append(row)
    return rows


# Training rows: two float features, then an int column that the rest of the
# script treats as the class label (it reads it back as row[-1]).
train = _read_csv_rows('D:\\train.csv', (float, float, int))
# Test rows: only the first two columns (features) are converted to float.
test = _read_csv_rows('D:\\test.csv', (float, float))
# Split the training set 4:1 into an estimation set (gj) and a validation
# set (yz). Each class is split separately, so both parts keep the class
# ratio of `train`. shuffle() is unseeded, so the split differs between runs.
train0 = [row for row in train if row[-1] == 0]
train1 = [row for row in train if row[-1] != 0]
l0 = len(train0) // 5
l1 = len(train1) // 5
shuffle(train0)
shuffle(train1)
yz = train0[:l0] + train1[:l1]
gj = train0[l0:] + train1[l1:]
# Gini index
def Gini(D):
    """Return the Gini impurity of the sample list *D*.

    Samples whose last element equals 0 count as class 0; every other sample
    counts as class 1. The result is ``1 - p0**2 - p1**2``, where p0 and p1
    are the class proportions. An empty list has impurity 0.0.
    """
    total = len(D)
    if total == 0:
        return 0.0
    n0 = sum(1 for d in D if d[-1] == 0)
    p0 = n0 / total
    # The original wrote len(D1)/len(D)**2, i.e. n1 / n**2, which is not p1**2.
    p1 = (total - n0) / total
    return 1 - p0 ** 2 - p1 ** 2
# Partition a sample list on one feature and a threshold value.
def subD(D, n, value):
    """Split *D* into ``(left, right)`` on feature column *n*.

    Samples with ``sample[n] <= value`` go to *left*; all others go to
    *right*. Within each list the original sample order is preserved.
    """
    left, right = [], []
    for row in D:
        (left if row[n] <= value else right).append(row)
    return left, right
# Select the best split point.
def best_feature(D, n_features=2):
    """Find the (feature, threshold) split of *D* with the lowest weighted Gini.

    Feature columns ``0 .. n_features-1`` are tried. For each one, every
    distinct value it takes in *D* is tried as a ``subD`` threshold. Splits
    that leave one side empty are skipped.

    Returns:
        ``(a, s, gini)``: the feature index, the threshold, and the weighted
        Gini of the best split. If no split yields two non-empty parts (e.g.
        *D* is empty, or all samples share the same feature values), the
        initial values ``(0, 0, 10)`` are returned unchanged.
    """
    gini = 10  # initial sentinel; callers can detect "no split found"
    a = 0
    s = 0
    total = len(D)
    for i in range(n_features):
        # dict.fromkeys removes duplicate thresholds but keeps first-seen
        # order. A repeated value gives the same Gini, which can never be
        # strictly smaller, so tie-breaking matches trying every sample.
        for threshold in dict.fromkeys(sample[i] for sample in D):
            Dl, Dr = subD(D, i, threshold)
            if not Dl or not Dr:
                continue
            temp_gini = (len(Dl) / total) * Gini(Dl) + (len(Dr) / total) * Gini(Dr)
            if temp_gini < gini:
                gini = temp_gini
                a = i
                s = threshold
    return a, s, gini
# A node of the CART tree.
class Node:
    """One node of the CART tree.

    Attributes (stored straight from the constructor arguments):
        D: the training samples that reached this node.
        n: index of the feature column this node splits on.
        value: split threshold; in this script, samples with
            ``sample[n] <= value`` are routed to ``lchild``.
        lchild, rchild: child nodes, or None for a leaf.
    """

    def __init__(self, D, n, value, lchild=None, rchild=None):
        self.D, self.n, self.value = D, n, value
        self.lchild, self.rchild = lchild, rchild
# Build the CART tree.
def createtree(D, min_size=200):
    """Recursively grow a binary CART classification tree on *D*.

    A node with more than *min_size* samples is split on the threshold
    chosen by ``best_feature``; any other node is a leaf. A node is also a
    leaf when all its samples share one label, or when no threshold
    separates it into two non-empty parts. In that last case, recursing
    would reprocess the same samples forever.

    Every node keeps its samples in ``node.D`` so that leaves can predict by
    majority vote.
    """
    n, value, gini = best_feature(D)
    node = Node(D, n, value)
    if len(D) <= min_size:
        return node
    if len({sample[-1] for sample in D}) <= 1:
        # Pure node: every descendant leaf would predict this same label.
        return node
    Dl, Dr = subD(D, n, value)
    if not Dl or not Dr:
        # No separating split exists; recursion would not shrink D.
        return node
    node.lchild = createtree(Dl, min_size)
    node.rchild = createtree(Dr, min_size)
    return node
#根据得到的CART树对测试集进行预测
def testtree(tree, test):
rig = 0
for sample in test:
node = tree
D = tree.D
while node.lchild!=None and node.rchild!=None:
if sample[node.n] <= node.value:
node = node.lchild
D = node.D
else:
node = node.rchild
D = node.D
D0 = []
D1 = []
for i in D:
if i[-1]==0:
D0.append(i)
else:
D1.append(i)
if len(D0) >= len(D1):
t = 0
else:
t = 1
if t == sample[-1]:
rig+=1
return rig/len(test)
# Grow the tree on the estimation split (gj) and report its accuracy on the
# held-out validation split (yz), rounded to 4 decimal places.
tree = createtree(gj)
validation_accuracy = testtree(tree, yz)
print('预测的正确率:', round(validation_accuracy, ndigits=4))
def pre_test(tree, test):
    """Append the predicted class label to every sample of *test*, in place.

    Each sample descends from the root (to ``lchild`` when
    ``sample[node.n] <= node.value``, otherwise to ``rchild``) until it
    reaches a node lacking a child. The prediction is the majority label
    among that node's training samples ``node.D``; ties go to class 0.
    Returns None.
    """
    for row in test:
        current = tree
        while current.lchild is not None and current.rchild is not None:
            goes_left = row[current.n] <= current.value
            current = current.lchild if goes_left else current.rchild
        labels = [member[-1] for member in current.D]
        zeros = sum(label == 0 for label in labels)
        row.append(0 if zeros >= len(labels) - zeros else 1)
# Label every row of the unlabelled test set in place, then write the rows
# (features plus appended prediction) to pre.csv.
pre_test(tree, test)
with open('D://pre.csv', 'wt', newline='') as out_file:
    csv.writer(out_file).writerows(test)
# Source article: 机器学习--python实现CART分类树--原生代码
# ("Machine learning: a CART classification tree in plain Python"),
# first published 2022-05-30 14:28:59.