ID3算法(Java实现)

package com.id3.node;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * ID3 decision-tree algorithm: reads training data from a properties file, builds a decision
 * tree, prints it to stdout and optionally writes it to a text file.
 *
 * @author JoMint
 */

public classID3Alogo {//存每个节点及其属性等相关变量

private ListtreeList;//存数据集

private List>dataList;//遍历决策树时的开始节点

privateAttributes startNode;//决策结果变量的值

private ListresultList;//结果属性节点

privateTreeNode resultNode;//决策树

privateString str;//构建决策树的开始调用方法

public voidID3(String id3Name,String readPath,String printPath){//初始化成员变量

initElement(id3Name);//读数据

readData(readPath);//构建决策树

cusTree(dataList, treeList, startNode);//System.out.println(startNode.getNextNode().getAttributes().get("Overcast").getLeafName());//遍历决策树,并把结果存入str中

printTree(startNode,"");//打印决策树

System.out.println(str);//输出决策树到文件

printTreetoTxt(printPath);

}/*** 初始化成员变量*/

private voidinitElement(String id3Name) {//存每个节点及其属性等相关变量

treeList = new ArrayList();//存数据集

dataList = new ArrayList>();//遍历决策树时的开始节点

startNode = newAttributes();//决策结果变量的值

resultList = new ArrayList();//结果属性节点

TreeNode resultNode = null;//决策树

str = id3Name+"决策树:\r\n";

}/*** 读数据*/

private voidreadData(String path) {

MapdataMap;

MapattrMap;

TreeNode treeNode;intnum;//创建读取properties文件的对象

Properties pro = newProperties();try{//为了读取中文字符,将读取文件的类型改为字符流读取

InputStream inputStream = newFileInputStream(path);

BufferedReader bf= new BufferedReader(newInputStreamReader(inputStream));//加载数据文件

pro.load(bf);//读取数据总个数

num = Integer.parseInt(pro.getProperty("datanum"));//读取属性及属性值

String attribute = pro.getProperty("nodeAndAttribute");//将每个属性分开,用数组存,遍历每个属性,再把每个属性的属性值分开,存到treeList中

String[] attArray = attribute.split(",");for (int i = 0; i < attArray.length; i++) {

treeNode= newTreeNode();

String[] temp= attArray[i].split(":");

String nodeName= temp[0];

String[] attr= temp[1].split("/");

treeNode.setNodeName(nodeName);

attrMap= new HashMap();

Attributes attributes;for (int j = 0; j < attr.length; j++) {//Map map = new HashMap();

attributes = newAttributes();//map.put(attr[j], 0);

attributes.setAttrName(attr[j]);

attrMap.put(attr[j], attributes);//存入结果变量的值,为最后的判断做铺垫

if(i == attArray.length-1){

resultList.add(attr[j]);

}

}

treeNode.setAttributes(attrMap);

treeList.add(treeNode);

}//遍历数据集,将数据按行存入dataList中

for (int i = 1; i <= num; i++) {

dataMap= new HashMap();

String key= "D"+i;

String[] colline= pro.getProperty(key).split(",");//System.out.println(key+"=="+colline.length);

for (int j = 0; j < treeList.size(); j++) {//System.out.println(treeList.size());

dataMap.put(treeList.get(j).getNodeName(), colline[j]);

}

dataList.add(dataMap);

}//得到结果属性的名字

resultNode = treeList.get(treeList.size()-1);//System.out.println("************************resultNode==" + resultNode + "***********************");

} catch(FileNotFoundException e) {//TODO Auto-generated catch block

e.printStackTrace();

}catch(IOException e) {//TODO Auto-generated catch block

e.printStackTrace();

}

}/*** 数据处理

*@paramcdataList

*@paramctreeList*/

private List dealData(List> dataList, ListtreeList){

List returnList= new ArrayList();int num =dataList.size();/** 统计数据集中每个属性的属性值个数*/Map attrMap = new HashMap();

MapresultMap;for (int i = 0; i < treeList.size(); i++) {for (int j = 0; j < dataList.size(); j++) {//获得当前数据集中当前列当前行的属性值

String key =dataList.get(j).get(treeList.get(i).getNodeName());

attrMap=treeList.get(i).getAttributes();//System.out.println(attrMap.get(key)+"=="+key);//计算样本中对应的属性变量的个数

attrMap.get(key).setAttrNum(attrMap.get(key).getAttrNum()+1);//System.out.println("->"+attrMap.get(key));//获得结果变量值

String result = dataList.get(j).get(treeList.get(treeList.size()-1).getNodeName());

resultMap=attrMap.get(key).getResultNum();//如果包含这个结果变量,则数量上加1; 如果不包含,赋初值为1

if(resultMap.containsKey(result)) {

resultMap.put(result, resultMap.get(result)+1);

}else{

resultMap.put(result,1);

}

}

}/** 计算熵*/DecimalFormat df= new DecimalFormat("#.###");for (int i = 0; i < treeList.size(); i++) {//遍历 Attributes//计算属性熵: gain

double gain = 0.0;for (Map.Entryelement : treeList.get(i).getAttributes().entrySet()) {

Attributes attr=treeList.get(i).getAttributes().get(element.getKey());

Map result =attr.getResultNum();//遍历每个 Attributes 的 resultNum//计算属性值的熵 :h

double h = 0.0;for (Map.Entryelement2 : result.entrySet()) {double resultNum = (double)result.get(element2.getKey());double attrNum = (double)attr.getAttrNum();

resultNum= resultNum/attrNum;

h-= (resultNum*(Math.log(resultNum)/Math.log((double)2)));

h=Double.parseDouble(df.format(h));

attr.setH(h);//System.out.println("resultNum=========="+resultNum);

}//System.out.println(" attr==>"+attr);

gain += ((double)attr.getAttrNum()/num)*attr.getH();

gain=Double.parseDouble(df.format(gain));//System.out.println("gain=="+gain);

}

treeList.get(i).setGain(gain);//System.out.println(" gain-->"+treeList.get(i));

}//将处理好的dataList和treeList放在returnList中返回

returnList.add(dataList);

returnList.add(treeList);returnreturnList;//System.out.println("***************************************************+++++++↓");//for (int i = 0; i < treeList.size(); i++) {//System.out.println(treeList.get(i));//}//System.out.println();//for (int i = 0; i < dataList.size(); i++) {//System.out.println(dataList.get(i));//}//

//System.out.println("================================================="+num+"条数据=="+treeList.size()+"个属性");//System.out.println("***************************************************+++++++↑");

}/*** 构建决策树

*@paramdataList

*@paramtreeList*/@SuppressWarnings("unchecked")private void cusTree(List> dataList, ListtreeList, Attributes cAttr){

List curryList= new ArrayList();//处理数据

curryList=dealData(dataList, treeList);//从 curryList 中得到 dataList 和 treeList

dataList = (List>)curryList.get(0);

treeList= (List)curryList.get(1);//判断当前处理的数据集中的决策结果,若决策结果相同的个数等于总的当前处理的数据集的条数,则遍历结束//将当前的决策结果放入当前判断的属性值的后边//返回到调用这个函数的父函数

for(TreeNode treeNode : treeList) {if(treeNode.getNodeName().equals(resultNode.getNodeName())) {for(String attr : resultList) {if (treeNode.getAttributes().get(attr).getAttrNum() ==dataList.size()) {

cAttr.setLeafName(attr);return;

}

}

}

}//System.out.println("=_=_=_=_=_=_=datalist==="+dataList);//System.out.println("=_=_=_=_=_=_=treelist==="+treeList);//寻找最优解//得到根节点

TreeNode rootNode = treeList.get(0);for(TreeNode treeNode : treeList) {if(!treeNode.getNodeName().equals(treeList.get(treeList.size()-1).getNodeName())){if(treeNode.getGain()

rootNode=treeNode;

}

}

}//System.out.println("*********↓↓↓↓↓↓↓↓***********当前根节点为:"+rootNode.getNodeName()+"***********↓↓↓↓↓↓↓↓*********");

cAttr.setNextNode(rootNode);//对当前根节点的属性进行遍历,寻找下一个节点//节点名

String nodeName =rootNode.getNodeName();//属性名

String attrName = "";//属性节点

Attributes attr = newAttributes();//当前节点的属性值集合

Map attrMap =rootNode.getAttributes();//遍历节点的每个属性值

for (Map.Entryentry : attrMap.entrySet()) {

attr=attrMap.get(entry.getKey());

attrName=attr.getAttrName();//System.out.println("*****************attrName========"+attrName+"******************");//得到新的data集合对象

List> newDataList = new ArrayList>();

Map newMap = new HashMap();//String attrName = rootNode.getAttributes().get("Sunny").getAttrName();

newMap.clear();//删除dataList中已处理过的节点数据//遍历dataList

for (Mapmap : dataList) {if(map.containsKey(nodeName)){if(map.get(nodeName).equals(attrName)){

newMap= new HashMap();for (Map.Entrym : map.entrySet()) {//如果该节点不是已处理过的节点

if(!m.getKey().equals(nodeName)){//得到新的节点

newMap.put(m.getKey(), map.get(m.getKey()));

}

}//将新的节点存入newDataList中

newDataList.add(newMap);

}

}

}//System.out.println("↓↓↓↓↓↓*******************新的data集合:*******************↓↓↓↓↓↓");//for (Map map : newDataList) {//System.out.println(map);//}//获得新的tree集合对象,而且值为初值

List newTreeList = new ArrayList();//将treeList中的数据清空

clearTree(treeList);//删除treeList中已处理过的节点

for(TreeNode treeNode : treeList) {if(!treeNode.getNodeName().equals(nodeName)){

newTreeList.add(treeNode);

}

}//System.out.println("↓↓↓↓↓↓*******************新的tree集合:*******************↓↓↓↓↓↓");//for (TreeNode treeNode : newTreeList) {//System.out.println(treeNode);//}//递归调用当前函数,继续找节点

cusTree(newDataList, newTreeList,attr);

}

}/*** 输出决策树

*@paramattr*/

private voidprintTree(Attributes attr, String ceil) {

String nodeName=attr.getNextNode().getNodeName();

Map attrMap =attr.getNextNode().getAttributes();

str+= ceil+"----"+nodeName+"\r\n";for (Map.EntrynextAttr : attrMap.entrySet()) {//如果当前属性值没有下一个节点,则将当前属性值的名称及决策结果输出

if(attrMap.get(nextAttr.getKey()).getNextNode() == null){

str+= ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";

str+= ceil+"----------"+attrMap.get(nextAttr.getKey()).getLeafName()+"\r\n";

}else{

str+= ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";

printTree(attrMap.get(nextAttr.getKey()),"------");

}

}

}/*** 打印决策树到txt文本

*@parampath*/

private voidprintTreetoTxt(String path){if(path == null || path.equals("")) return;

File file= newFile(path);

File folder=file.getParentFile();

FileWriter fw;try{if(!folder.exists()){

folder.mkdirs();

file.createNewFile();

}

fw= newFileWriter(file);

fw.write(str);

fw.flush();

fw.close();

}catch(IOException e) {//TODO Auto-generated catch block

e.printStackTrace();

}

}/*** 还原初始数据

*@paramtreeList*/

private void clearTree(ListtreeList){for(TreeNode treeNode : treeList) {

Map map =treeNode.getAttributes();for (Map.Entryentry : map.entrySet()) {

Attributes attr=map.get(entry.getKey());

attr.setAttrNum(0);

attr.setH(0);

Map map2 =attr.getResultNum();

map2.clear();

}

treeNode.setGain(0);

}

}

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值