这里只写一下用C++简单实现的ID3算法决策树
ID3算法基于信息熵和信息增益(information gain,下文称信息获取量)
每次建立新节点时,选取信息获取量最大的属性进行分割,即分割后加权信息熵下降最多的属性
决策树还有很多其他算法(例如C4.5使用信息增益率,CART使用基尼指数),它们的主要区别在于选择分割属性的衡量标准不同
实质都是按照贪心自上而下地建树
如果深度过深,还要采取剪枝的手段
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Index type used by the loops that walk rows and columns.
typedef unsigned int ui;
// A data set: one inner vector per sample. The values are the integers read
// by main, and the last value of each row is used as the class label.
typedef vector< vector<int> > dv;
// Capacity of Node::flag, i.e. the largest attribute index that can be tracked.
const int maxm = 100;
// Capacity of the global node array.
const int maxn = 1000;
// Tolerance used when comparing floating-point probabilities and gains.
const double eps = 1e-7;
// One node of the decision tree. Nodes are kept in the global array `node`;
// build() stores the children of node x at indices 2*x and 2*x+1.
struct Node
{
    bool flag[maxm]; // flag[i] is set once attribute i has been used on the way to this node
    int st;          // attribute this node splits on, or -1 when the node is a leaf
    int yes;         // number of samples at this node with a non-zero label
    int no;          // number of samples at this node with a zero label
};
Node node[maxn];
// Binary entropy, in bits, of a two-class distribution in which one class has
// probability p. Values within eps of 0 or 1 are treated as a pure
// distribution and yield 0. A NaN argument fails both tolerance checks and
// propagates to the result.
double cal_entropy(double p)
{
    const bool near_zero = abs(p) <= eps;
    const bool near_one = abs(p - 1) <= eps;
    if (near_zero || near_one)
        return 0;

    const double q = 1 - p;
    const double first = p * log(p) / log(2);   // division by log(2) converts to base 2
    const double second = q * log(q) / log(2);
    return -(first + second);
}
double split(dv v, int k) //算出如果以第k个属性分割得到的信息获取量
{
int v1, v2, n1, n2;
v1 = v2 = n1 = n2 = 0;
for(ui i = 0; i < v.size(); i++)
{
if(v[i][k])
{
n1++;
if(v[i][v[i].size()-1]) v1++;
}
else
{
n2++;
if(v[i][v[i].size()-1]) v2++;
}
}
int n = n1+n2;
double ans = (double)n1/n*cal_entropy((double)v1/n1) + (double)n2/n*cal_entropy((double)v2/n2);
return cal_entropy((double)(v1+v2)/n) - ans;
}
void build(int x, dv vnode) //按照贪心算法建树
{
double ans = -1;
int k = -1;
for(ui i = 0; i < vnode.size(); i++)
if(vnode[i][vnode[i].size()-1]) node[x].yes++;
node[x].no = vnode.size() - node[x].yes;
for(ui i = 0; i < vnode[0].size()-1; i++)
if(!node[x].flag[i] && (split(vnode, i) - ans > eps))
{
ans = split(vnode, i);
k = i;
}
node[x].st = k;
printf("%d %d %d %d\n", x, node[x].yes, node[x].no, node[x].st); //先序遍历输出树的结构
if(k == -1) return;
dv v1, v2;
for(ui i = 0; i < vnode.size(); i++)
if(vnode[i][k]) v1.push_back(vnode[i]);
else v2.push_back(vnode[i]);
for(ui i = 0; i < v1[0].size(); i++)
{
node[x*2].flag[i] = node[x].flag[i];
node[x*2+1].flag[i] = node[x].flag[i];
}
node[x*2].flag[k] = node[x*2+1].flag[k] = 1;
build(x*2, v1); build(x*2+1, v2);
}
int n; // number of training rows read by main
int m; // number of values per row (also the length of the sample main classifies)
int x; // scratch variable main reads each input value into
dv v;  // training set filled by main and passed to build
// Classifies one sample with the tree stored in `node`, starting at node x.
//
// At every internal node the sample's value for that node's split attribute
// selects the child: non-zero goes to 2*x, zero goes to 2*x+1. At a leaf
// (st == -1) the result is 1 if the leaf saw more non-zero than zero labels
// during training, otherwise 0.
//
// `vv` is taken by const reference so the sample is not copied at each level.
int dfs(int x, const std::vector<int> &vv)
{
    while (node[x].st != -1)
        x = vv[node[x].st] ? 2 * x : 2 * x + 1;
    return node[x].yes > node[x].no;
}
std::vector<int> vv; // the sample that main reads after training and passes to dfs
int main()
{
freopen("a.txt", "r", stdin);
cin>>n>>m;
v.resize(n);
for(int i = 0; i < n; i++)
for(int j = 0; j < m; j++)
{
cin>>x;
v[i].push_back(x);
}
build(1, v);
for(int i = 0; i < m; i++) cin>>x, vv.push_back(x);
cout<
}