上一篇https://blog.csdn.net/To_be_to_thought/article/details/86321842介绍了在字母表下Trie树的两种实现方式,如果不是字母表,而是一些具有层次关系、包含关系的对象呢,比如如下问题:
测试数据如下:
机器学习,线性模型,线性回归,最小二乘法
,神经网络,神经元模型,激活函数
,,多层网络,感知机
,,,连接权
,强化学习,有模型学习,策略评估
,,,策略改进
,,免模型学习,蒙特卡洛方法
,,,时序差分学习
,,模仿学习,直接模仿学习
,,,逆强化学习
将上述简化版的目录转成json:
我一开始想到的就是用Trie的思想,而树节点中存储的是字符串对象,不再是字符,这时数组映射的小技巧显然是无法用了,这里我采用通用的HashMap来做存储和映射,树节点的定义如下:
private class TrieNode{
String[] record;
HashMap<String,TrieNode> map;
boolean isEnd;
public TrieNode()
{
this.record=null;
this.map=new HashMap<>();
}
}
这里使用逻辑变量isEnd来标记从根节点到当前树节点的路径形成的字符串数组是否存在于“字典”中,如果为false,则record==null,如果为true,则record记录了这条路径上的字符串。
那么整棵树的效果图如下(这里只将第一条记录完整的插入树中,其他记录因为篇幅问题省略),isEnd=true用蓝色填充:
树节点还有插入和查询等方法:
private class TrieNode{
String[] record;
HashMap<String,TrieNode> map;
boolean isEnd;
public TrieNode()
{
this.record=null;
this.map=new HashMap<>();
}
public boolean containsKey(String str)
{
return this.map.containsKey(str);
}
public TrieNode get(String str)
{
if(this.map.containsKey(str))
return this.map.get(str);
else
return null;
}
public void put(String str,TrieNode node)
{
assert str!=null && node!=null;
this.map.put(str,node);
}
public void setEnd()
{
this.isEnd=true;
}
public boolean isEnd()
{
return this.isEnd;
}
}
整个目录树的定义使用一个根节点和一个全局的HashMap来实现的,这个HashMap存储着节点字符串值到树节点的映射,这样可以更加快速的实现查询、插入操作。
public class CatalogTrie {
private TrieNode root;
HashMap<String,TrieNode> AllMap;
public CatalogTrie()
{
this.root=new TrieNode();
this.AllMap=new HashMap<>();
}
}
整个基于HashMap的字典树的实现和测试代码如下:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
public class CatalogTrie {
private TrieNode root;
HashMap<String,TrieNode> AllMap;
public CatalogTrie()
{
this.root=new TrieNode();
this.AllMap=new HashMap<>();
}
private class TrieNode{
String[] record;
HashMap<String,TrieNode> map;
boolean isEnd;
public TrieNode()
{
this.record=null;
this.map=new HashMap<>();
}
public boolean containsKey(String str)
{
return this.map.containsKey(str);
}
public TrieNode get(String str)
{
if(this.map.containsKey(str))
return this.map.get(str);
else
return null;
}
public void put(String str,TrieNode node)
{
assert str!=null && node!=null;
this.map.put(str,node);
}
public void setEnd()
{
this.isEnd=true;
}
public boolean isEnd()
{
return this.isEnd;
}
}
public void insert(String[] catalogue)
{
TrieNode node=this.root;
assert catalogue!=null && catalogue.length>=1;
for(int i=0;i<catalogue.length;i++)
{
String str=catalogue[i];
if(!node.containsKey(str))
{
TrieNode tmp=new TrieNode();
node.put(str,tmp);
this.AllMap.put(str,node);
}
node=node.get(str);
}
node.setEnd();
node.record=catalogue.clone();
}
public void insert(TrieNode father,String fatherKey,String[] catalogue,int start)
{
assert father.containsKey(fatherKey);
TrieNode node=father.get(fatherKey);
for(int i=start;i<catalogue.length;i++)
{
String str=catalogue[i];
if(!node.containsKey(str))
{
TrieNode tmp=new TrieNode();
node.put(str,tmp);
this.AllMap.put(str,node);
}
node=node.get(str);
}
}
public TrieNode searchPrefix(String[] parent)
{
assert parent!=null && parent.length>=1;
TrieNode node=this.root;
for(int i=0;i<parent.length;i++)
{
String curStr=parent[i];
if(node.containsKey(curStr))
node=node.get(curStr);
else
return null;
}
return node;
}
public TrieNode searchPrefix(String parent)
{
return this.AllMap.get(parent);
}
/*
public Iterable<String> subCatalogues(String parent)
{
}
*/
public boolean search(String[] parent)
{
assert parent!=null && parent.length>=1;
TrieNode node=searchPrefix(parent);
return node!=null && node.isEnd();
}
public static ArrayList<String[]> readTXTByLine(String fileName) {
File file = new File(fileName);
ArrayList<String[]> ret=new ArrayList<>();
BufferedReader reader = null;
try {
System.out.println("以行为单位读取文件内容,一次读一整行:");
reader = new BufferedReader(new FileReader(file));
String tempString = null;
int line = 1;
// 一次读入一行,直到读入null为文件结束
while ((tempString = reader.readLine()) != null)
{
// 显示行号
String[] lines=tempString.split(",");
ret.add(lines);
line++;
}
reader.close();
} catch (IOException e)
{
e.printStackTrace();
} finally
{
if (reader != null)
{
try {
reader.close();
} catch (IOException e1) {
}
}
}
return ret;
}
public static ArrayList<Integer> commaCount(ArrayList<String[]> strs)
{
ArrayList<Integer> ret=new ArrayList<>();
for(int i=0;i<strs.size();i++)
{
int count=0;
String[] tmp=strs.get(i);
for(String str:tmp)
{
if(str.equals(""))
count++;
else
break;
}
ret.add(count);
}
return ret;
}
public static int[] fatherIndex(ArrayList<Integer> comma)
{
int[] ret=new int[comma.size()];
for(int i=comma.size()-1;i>=0;i--)
{
for(int j=i-1;j>=0;j--)
{
if(comma.get(i)>comma.get(j))
{
ret[i]=j;
break;
}
}
}
ret[0]=-1;
return ret;
}
public static CatalogTrie MakeTrie(String filename)
{
ArrayList<String[]> rawData=readTXTByLine(filename);
ArrayList<Integer> n_comma=commaCount(rawData);
int[] fatherIndex=fatherIndex(n_comma);
CatalogTrie ret=new CatalogTrie();
ret.insert(rawData.get(0));
for(int i=1;i<rawData.size();i++)
{
String[] tmp=rawData.get(fatherIndex[i]);
String key=tmp[n_comma.get(fatherIndex[i])];
TrieNode fatherNode=ret.AllMap.get(key);
ret.insert(fatherNode,key,rawData.get(i),n_comma.get(i));
}
return ret;
}
}
测试代码:
public static void main(String[] args)
{
//测试一
String[] a={"机器学习","线性模型","线性回归","最小二乘法"};
String[] b={"机器学习","神经网络","神经元模型","激活函数"};
String[] c={"机器学习","神经网络","多层网络","感知机"};
String[] d={"机器学习","强化学习","有模型学习","策略评估"};
String[] e={"机器学习","强化学习","有模型学习","策略改进"};
CatalogTrie CT=new CatalogTrie();
CT.insert(a);
CT.insert(b);
CT.insert(c);
CT.insert(d);
CT.insert(e);
//测试二
String filename="C:/Users/江/Desktop/数据转换题目1/concepts.txt";
MakeTrie(filename);
}