人工智能原理(二)——决策树

一 目标:使用决策树进行分类

二 评价属性

1、 信息增益(ID3):

2、信息增益率(C4.5):

3、基尼指数:

三 代码

import numpy as np
from math import log

message_range = [4, 4, 4, 3, 3, 3] # number of distinct values each of the 6 features can take
# Meaning of the numeric codes stored in each matrix column.
# NOTE(review): the string values below look like the UCI Car Evaluation
# dataset — confirm against the data files.
# global message_to_num = [{},{},{},{},{},{}]
# message_to_num[0] = {'low':0, 'med':1, 'high':2, 'vhigh':3}
# message_to_num[1] = message_to_num[0]
# message_to_num[2] = {'2':0, '3':1, '4':2, '5more':3}
# message_to_num[3] = {'2':0, '4':1, 'more':2}
# message_to_num[4] = {'small':0, 'med':1, 'big':2}
# message_to_num[5] = {'low':0, 'med':1, 'high':2}

def readMat(file):
    """Read a comma-separated data file into an (n, 7) integer matrix.

    Each line must contain 7 comma-separated integer codes
    (6 feature values plus the class label in column 6).

    file: an open text-file object (position is reset before reading).
    Returns: np.ndarray of shape (n, 7), dtype int.

    Fixes two defects of the original:
    - the file was read twice (readlines + seek + readlines);
    - each line was trimmed with content[:-1], which chopped the last
      digit off a final line that has no trailing newline.
    """
    file.seek(0, 0)
    # splitlines() strips the newline safely; skip blank lines (a blank
    # line would have crashed the original on split(',')).
    lines = [line for line in file.read().splitlines() if line.strip()]
    mat = np.zeros((len(lines), 7), int)
    for i, line in enumerate(lines):
        fields = line.split(',')
        for col in range(7):
            mat[i][col] = int(fields[col])
    return mat
       
def readTrainData(k):
    """Load the k cross-validation folds train_mat0.txt .. train_mat{k-1}.txt.

    Returns a list of k matrices, one per fold file.
    """
    folds = []
    for fold in range(k):
        with open('train_mat{}.txt'.format(fold), 'r') as fh:
            folds.append(readMat(fh))
    return folds

def readTestData():
    """Read the test set from test.txt into a matrix."""
    with open('test.txt', 'r') as fh:
        mat = readMat(fh)
    return mat

############################
#   ID3 helper functions   #
############################

def getEntropy(p):
    """Binary entropy H(p) in bits; defined as 0 at p == 0 and p == 1."""
    if p in (0, 1):
        return 0
    q = 1 - p
    return -p * log(p, 2) - q * log(q, 2)

def getRemainder(l):
    """Weighted average label entropy after splitting on a feature.

    l: (k, 2) count matrix; l[v] = [#label-0, #label-1] among samples
       whose feature takes value v. Empty value groups are skipped.
    """
    total = l.sum()
    remainder = 0
    for neg, pos in l:
        group = neg + pos
        if group == 0:
            continue
        remainder += (group / total) * getEntropy(neg / group)
    return remainder

def getGain(l):
    """ID3 information gain: label entropy before the split minus the remainder."""
    neg, pos = l.sum(axis=0)
    return getEntropy(neg / (neg + pos)) - getRemainder(l)

#############################
#   C4.5 helper functions   #
#############################

def getSplit(l):
    """C4.5 split information: entropy (in bits) of the feature-value distribution.

    l: (k, 2) count matrix per feature value; empty values are skipped.

    Fix: use log base 2 like getEntropy/getGain. The original used the
    natural log, which scaled every gain ratio by ln 2 — the scaling was
    uniform so feature ranking was unaffected, but the values were not
    true C4.5 split information.
    """
    counts = l.sum(axis=1)
    tot = counts.sum()
    ans = 0
    for ele in counts:
        if ele == 0:
            continue
        ans += (ele / tot) * log(ele / tot, 2)
    return -ans

def getRatio(l):
    """C4.5 gain ratio: information gain normalised by split information."""
    gain = getGain(l)
    split = getSplit(l)
    return gain / split

#############################
#   CART helper functions   #
#############################

def getGiniHelper(row):
    """Gini impurity of a single [#label-0, #label-1] count pair."""
    neg, pos = row
    total = neg + pos
    return 1 - (neg / total) ** 2 - (pos / total) ** 2

def getGini(l):
    """Weighted Gini index of the split described by count matrix l."""
    group_sizes = l.sum(axis=1)
    tot = group_sizes.sum()
    ans = 0
    for size, counts in zip(group_sizes, l):
        ans += (size / tot) * getGiniHelper(counts)
    return ans

##########################################
# Build the tree and report its accuracy #
##########################################

# Decision-tree node layout:
# [index, target, child_for_value_0, child_for_value_1, ...]
# index  = the feature this node splits on
# target = the criterion score of that feature
# each child slot is either a nested node list or an int leaf label;
# the child for feature value v is looked up at position v + 2.

# The tree is built depth-first.

# Concatenate all folds except the held-out one into the training matrix.
def buildTrainSet(mat_list, index, k):
    """Stack every fold except mat_list[index] into one training matrix.

    mat_list: list of k (n_i, 7) matrices; index: held-out fold; k: fold count.

    Fixes a bug in the original: when index == 0 the seed matrix was
    mat_list[1], but the loop over range(1, k) then appended mat_list[1]
    again, so fold 1 appeared twice in the training set.
    """
    parts = [mat_list[i] for i in range(k) if i != index]
    return np.vstack(parts)

# Build the per-feature count matrix fed to the entropy / gini functions.
def getList(mat, index):
    """Count label outcomes per value of feature `index`.

    Returns a (message_range[index], 2) matrix: row v holds
    [#label-0, #label-1] among samples whose feature `index` equals v
    (column 6 of mat is the label).
    """
    counts = np.zeros((message_range[index], 2), int)
    for sample in mat:
        counts[sample[index], sample[6]] += 1
    return counts

def testFeature(l, sign):
    """Score one candidate split: sign 0 -> gain, 1 -> gain ratio, else Gini.

    The Gini index is negated so that every criterion is maximised.
    """
    if sign == 0:
        return getGain(l)
    if sign == 1:
        return getRatio(l)
    return -getGini(l)

# Extract the rows routed to one child node.
def getSubmatrix(mat, index, value):
    """Return the rows of mat whose feature `index` equals `value`.

    Uses a boolean mask instead of the original row-by-row np.vstack,
    which was O(n^2) and hard-coded a row width of 7. An empty result
    keeps shape (0, width) as before.
    """
    return mat[mat[:, index] == value]

def createNode(mat, index_list, sign, node):
    """Recursively build one decision-tree node (depth-first), mutating `node`.

    mat: training rows reaching this node (shape (n, 7); column 6 = label).
    index_list: feature indices still available for splitting here.
    sign: criterion selector (0 -> gain, 1 -> gain ratio, else Gini).
    node: list filled in place as [feature_index, score, child_or_label, ...];
          the child slot for feature value v sits at position v + 2.
    """
    features = []
    for i in index_list:
        l = getList(mat, i)
        tmp = testFeature(l, sign)
        features.append((tmp, i))
    features.sort()
    best = features[-1] # highest (score, feature) pair; ties favour the larger index
    # Optimisation: cap the tree depth — when only one candidate feature
    # remains, every child becomes a leaf. A value with no label-0 samples
    # gets label 1, otherwise label 0 (note: not a true majority vote).
    if len(features) == 1:
        node.append(best[1])
        node.append(best[0])
        l = getList(mat, best[1])
        for i in range(l.shape[0]):
            if l[i][0] == 0:
                node.append(1)
            else:
                node.append(0)
        return

    # remove the chosen feature from the candidates handed to the children
    new_index_list = index_list[:]
    new_index_list.remove(best[1])
    
    # fill this node: feature index, its score, then one slot per feature value
    node.append(best[1])
    node.append(best[0])
    l = getList(mat, best[1])
    for i in range(l.shape[0]):
        if l[i][0] == 0:
            # no label-0 samples for this value (incl. empty group) -> leaf 1
            node.append(1)
        elif l[i][1] == 0:
            # no label-1 samples -> leaf 0
            node.append(0)
        else:
            # mixed labels -> recurse into a new child node
            node.append([])
            submatrix = getSubmatrix(mat, best[1], i)
            createNode(submatrix, new_index_list[:], sign, node[-1])

def getTree(mat_list, index, sign, k):
    """Train one tree with criterion `sign`, holding out fold `index`."""
    # build the training set from the k-1 remaining folds
    train_mat = buildTrainSet(mat_list, index, k)
    # all six features are available at the root; one is consumed per level
    available = [0, 1, 2, 3, 4, 5]
    root = []
    createNode(train_mat, available, sign, root)
    return root

###############################################
# Validation-set and test-set helper functions #
###############################################

# Predict the label of a single sample row.
def getPredict(node, row):
    """Walk the tree from `node`, following row's feature values, to a leaf label."""
    child = node[row[node[0]] + 2]
    if isinstance(child, list):
        return getPredict(child, row)
    return child

# Predict the test set and dump the labels to disk.
def testsetPredice(mat, root):
    """Predict every row of `mat` with the tree and write one label per line
    to test_res.txt. (Name kept as-is — 'Predice' is a historical typo, but
    callers use it.)
    """
    with open('test_res.txt', 'w') as out:
        out.writelines('{}\n'.format(getPredict(root, row)) for row in mat)

# Score the tree on the held-out validation fold.
def validsetPredict(mat, root, file):
    """Write the fraction of rows in `mat` the tree classifies correctly
    (column 6 holds the true label) to the open log `file`.
    """
    total = mat.shape[0]
    correct = sum(1 for row in mat if getPredict(root, row) == row[6])
    file.write('ratio: {}\n'.format(float(correct / total)))

# Evaluate every (criterion, held-out fold) combination.
def findBestTree(mat_list, k):
    """Run k-fold cross-validation for each of the 3 criteria and log
    every fold's accuracy to journal.txt.
    """
    with open('journal.txt', 'w') as journal:
        # training : validation split is (k-1) : 1, i.e. 9 : 1 for k = 10
        for criterion in range(3):
            for fold in range(k):
                journal.write('{} {}\n'.format(criterion, fold))
                tree = getTree(mat_list, fold, criterion, k)
                validsetPredict(mat_list[fold], tree, journal)

# Script entry point.
def main():
    """Load the data, search over criteria/folds, then predict the test set."""
    mat_list = readTrainData(10)
    test_mat = readTestData()

    # log the accuracy of every criterion / fold combination
    findBestTree(mat_list, 10)

    # final model: gain ratio (sign = 1), holding out fold 5
    root = getTree(mat_list, 5, 1, 10)
    testsetPredice(test_mat, root)

if __name__ == '__main__':  # run only when executed as a script, not on import
    main()

 

没有更多推荐了,返回首页