摘要
本部分介绍如何基于Java语言构建字典树(Trie),包括字典树的插入、剪枝与遍历操作。字典树的原理不再赘述,代码实现如下。
实现部分
工具类Tools.java:主要实现大数据集的采样、指定列的提取以及数据规模的统计。
package main;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashSet;
public class Tools {
    /**
     * Copies the first {@code size} lines of a (large) data file into a new sample file.
     * Copying stops early, without error, when the source has fewer than {@code size} lines.
     * Both files use the platform default charset.
     *
     * @param src  path of the source file
     * @param des  path of the sample file; it is created or overwritten
     * @param size maximum number of lines to copy
     * @throws IOException if the source cannot be read or the sample cannot be written
     */
    public static void getSample(String src, String des, long size) throws IOException {
        long written = 0;
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            // Check the bound first so no extra line is read, then stop at EOF:
            // writing a null line would throw NullPointerException.
            while (written < size && (line = br.readLine()) != null) {
                bw.write(line);
                bw.write(System.lineSeparator());
                written++;
            }
        }
        System.out.println("get sample successful, size is: " + written);
    }

    /**
     * Extracts one tab-separated column (0-based) from every line of {@code src} and writes
     * the trimmed values to {@code des}, one per line. Values that are blank after trimming,
     * and rows that do not have the requested column, are skipped.
     *
     * @param src    path of the tab-separated input file
     * @param des    path of the output file; it is created or overwritten
     * @param colNum 0-based index of the column to extract
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException if reading or writing fails
     */
    public static void getColumn(String src, String des, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0: " + colNum);
        }
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = line.split("\\t");
                // Short rows are skipped rather than aborting the whole extraction.
                if (colNum >= fields.length) {
                    continue;
                }
                String column = fields[colNum].trim();
                if (column.isEmpty()) { // drop empty records
                    continue;
                }
                bw.write(column);
                bw.write(System.lineSeparator()); // one record per line
            }
        }
        System.out.println("extract column to file successful, cloumnNum is: " + colNum);
    }

    /**
     * Prints data-size statistics for {@code src}:
     * the number of distinct characters (Unicode code points) and
     * the number of distinct lines.
     *
     * @param src path of the file to inspect (platform default charset)
     * @throws IOException if the file cannot be read
     */
    public static void getSize(String src) throws IOException {
        HashSet<String> chars_set = new HashSet<>(); // distinct single characters
        HashSet<String> terms_set = new HashSet<>(); // distinct whole lines
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                terms_set.add(line);
                // Walk code points so supplementary characters are not split into halves,
                // and an empty line contributes no (empty-string) "character".
                int i = 0;
                while (i < line.length()) {
                    int next = i + Character.charCount(line.codePointAt(i));
                    chars_set.add(line.substring(i, next));
                    i = next;
                }
            }
        }
        System.out.println("chars_set.size():" + chars_set.size());
        System.out.println("terms_set.size():" + terms_set.size());
    }

    public static void main(String[] args) throws IOException {
        String path1 = ""; // TODO: input file path
        String path2 = ""; // TODO: output file path
        // Extract column 1, i.e. the second column (columns are 0-based).
        getColumn(path1, path2, 1);
        getSize(path2);
    }
}
节点类Node.java:字典树的树节点。
package main;
import java.util.LinkedList;
/**
* 树节点类
* */
public class Node {
char val; // 保存当前节点的字符
int count; // 统计经过当前节点的字符串数目
boolean isEnd; // 标志当前节点是否是一个词的末尾
Node parent; // 存储当前节点的父节点
LinkedList<Node> childList; // 存储当前节点的直接子节点
int org_count; // 当前节点的原始频次
int max_org_count; // 当前节点路径上的最大org_count
/**
* 带参构造方法,构造含有指定字符val的节点,并指明新建节点的父节点
* */
public Node(char val, Node parent) {
this.val = val;
this.count = 1;
this.isEnd = false;
this.parent = parent;
this.childList = new LinkedList<Node>();
this.org_count = 0;
this.max_org_count = 0;
}
/**
* 无参构造方法,初始化val=' ', count=0, isEnd=false, parent=null, childList=new List()
* */
public Node() {
this(' ', null);
}
/**
* 根据指定的字符,获取当前节点的子节点。子节点不存在则返回null
* */
public Node getNode(char val) {
// 依次遍历链表
for (Node child : childList) {
if (child.val == val)
return child;
else
continue;
}
return null;
}
@Override
public String toString() {
return this.val + ":" + this.count;
}
}
字典树类TrieTree.java:字典树的构建、遍历与剪枝。
package main;
import java.util.AbstractMap.SimpleEntry;
import java.util.LinkedList;
import java.util.Map.Entry;
/**
 * A trie (prefix tree) built from {@link Node} objects.
 * <p>
 * Every inserted string adds one to {@code count} on each node along its path and one to
 * {@code org_count} on its final node. {@code Main} inserts each term reversed. In that usage
 * a root-to-node path spells a term suffix, and {@link #toRoot(Node)} yields the original
 * term.
 * <p>
 * Typical use:
 * <pre>
 * TrieTree tree = new TrieTree();
 * tree.insert(new StringBuilder("advisor").reverse().toString());
 * tree.updateMaxOrgCount();   // required before cart()
 * tree.cart(minCount);        // optional pruning
 * LinkedList&lt;SimpleEntry&lt;Integer, String&gt;&gt; paths = tree.level();
 * </pre>
 *
 * @author stevinpan
 */
public class TrieTree {
    /** Root node; it never stores a character of its own. */
    public Node root = null;
    /** Threshold recorded by {@link #cart(int)}; {@link #toParents(Node)} also filters by it. */
    private int min_count = 0;

    /** Creates an empty trie with a fresh root node. */
    public TrieTree() {
        root = new Node();
    }

    /**
     * Reports whether {@code word} was inserted as a complete word, not merely as a prefix.
     *
     * @param word the string to look up, in the same orientation it was inserted
     * @return {@code true} if the path for {@code word} exists and ends at a word-end node;
     *         {@code false} for a {@code null} word, a missing root, or a missing path
     */
    public boolean isExist(String word) {
        if (root == null || word == null) {
            return false;
        }
        Node curr = root;
        for (int index = 0; index < word.length(); index++) {
            curr = curr.getNode(word.charAt(index));
            if (curr == null) {
                return false;
            }
        }
        return curr.isEnd;
    }

    /**
     * Inserts {@code word}. Duplicates are inserted again, so counts accumulate.
     *
     * @param word the string to insert
     * @return {@code true} on success; {@code false} (and no change) for a {@code null} or
     *         empty word
     */
    public boolean insert(String word) {
        // An empty word would mark the root itself as a word end. level() would then call
        // toParents(root.parent), i.e. toParents(null), and fail.
        if (word == null || word.isEmpty()) {
            return false;
        }
        if (root == null) {
            root = new Node();
        }
        Node curr = root;
        for (int index = 0; index < word.length(); index++) {
            char c = word.charAt(index);
            Node next = curr.getNode(c);
            if (next == null) {
                next = new Node(c, curr); // a new node already starts with count = 1
                curr.childList.add(next);
            } else {
                next.count++;
            }
            curr = next;
        }
        curr.isEnd = true;
        curr.org_count++;
        return true;
    }

    /**
     * Prints every word-end node as {@code count<TAB>path}, one per line.
     * The traversal is depth-first (pre-order), not level-order.
     */
    public void printAll() {
        if (root != null) {
            printAll(root);
        }
    }

    private void printAll(Node node) {
        for (Node child : node.childList) {
            if (child.isEnd) {
                System.out.println(child.count + "\t" + toRoot(child));
            }
            printAll(child);
        }
    }

    /**
     * Collects one entry per word-end node, visiting the tree depth-first (pre-order).
     * Despite the name, this is not a breadth-first traversal.
     * <p>
     * Each entry's key is the node's {@code count}. Its value is
     * {@code org_count<TAB>path}, followed by {@code <TAB>ancestorPath} for each qualifying
     * word-end ancestor (see {@link #toParents(Node)}).
     *
     * @return the collected entries; empty if the trie has no root
     */
    public LinkedList<SimpleEntry<Integer, String>> level() {
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<SimpleEntry<Integer, String>>();
        if (root != null) {
            level(root, result);
        }
        return result;
    }

    private void level(Node node, LinkedList<SimpleEntry<Integer, String>> results) {
        if (node.isEnd) {
            String full_path = toRoot(node);
            String parent_paths = toParents(node.parent);
            results.add(new SimpleEntry<Integer, String>(node.count,
                    node.org_count + "\t" + full_path + parent_paths));
        }
        for (Node child : node.childList) {
            level(child, results);
        }
    }

    /**
     * Reads characters from {@code node} up to (excluding) the root.
     *
     * @param node the starting node
     * @return the characters in bottom-up order; {@code ""} for the root or {@code null}
     */
    public String toRoot(Node node) {
        StringBuilder sb = new StringBuilder();
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            sb.append(curr.val);
        }
        return sb.toString();
    }

    /**
     * Walks from {@code node} (inclusive) up to the root. For every word-end node with
     * {@code org_count > min_count} it appends {@code <TAB>} plus that node's
     * {@link #toRoot(Node)} path.
     * <p>
     * NOTE(review): cart() keeps nodes whose max_org_count equals min_count, while this
     * filter requires org_count strictly greater than min_count. Confirm whether the
     * boundaries should match.
     *
     * @param node the first node to examine; {@code null} or the root yields {@code ""}
     * @return the tab-prefixed ancestor paths, nearest first
     */
    public String toParents(Node node) {
        StringBuilder res = new StringBuilder();
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            if (curr.isEnd && curr.org_count > this.min_count) {
                res.append('\t').append(toRoot(curr));
            }
        }
        return res.toString();
    }

    /**
     * Returns every node that has no children. For a trie with no words this is just the root.
     *
     * @return the leaves in depth-first order; empty if the trie has no root
     */
    public LinkedList<Node> findLeaf() {
        LinkedList<Node> leafList = new LinkedList<>();
        if (root != null) {
            findLeaf(root, leafList);
        }
        return leafList;
    }

    private void findLeaf(Node node, LinkedList<Node> leafList) {
        if (node.childList.isEmpty()) {
            leafList.add(node);
        } else {
            for (Node child : node.childList) {
                findLeaf(child, leafList);
            }
        }
    }

    /**
     * Produces one entry per leaf by climbing from the leaf toward the root.
     * <p>
     * Each entry's key is the leaf's {@code count}. Its value is {@code <TAB>count}, followed
     * by {@code <TAB>path} for each word-end node met on the way up, deepest first.
     *
     * @return the collected entries; empty for a trie with no words
     */
    public LinkedList<SimpleEntry<Integer, String>> back() {
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<>();
        for (Node leaf : findLeaf()) {
            // In an empty trie the only "leaf" is the root, which does not represent a word.
            if (leaf == root) {
                continue;
            }
            int count = leaf.count;
            StringBuilder sb = new StringBuilder();
            sb.append('\t').append(count);
            for (Node curr = leaf; curr != null && curr != root; curr = curr.parent) {
                if (curr.isEnd) {
                    sb.append('\t').append(toRoot(curr));
                }
            }
            result.add(new SimpleEntry<Integer, String>(count, sb.toString()));
        }
        return result;
    }

    /**
     * Sets {@code max_org_count} on every word-end node. The value is the larger of the
     * node's own {@code org_count} and its nearest word-end ancestor's {@code max_org_count}.
     * Must be called before {@link #cart(int)}.
     */
    public void updateMaxOrgCount() {
        if (root != null) {
            updateMaxOrgCount(root);
        }
    }

    private void updateMaxOrgCount(Node node) {
        if (node.isEnd) {
            // Pre-order traversal: the nearest word-end ancestor has already been updated.
            Node parent_isEnd = getParent(node);
            node.max_org_count = parent_isEnd == null
                    ? node.org_count
                    : Math.max(node.org_count, parent_isEnd.max_org_count);
        }
        for (Node child : node.childList) {
            updateMaxOrgCount(child);
        }
    }

    /**
     * Returns the nearest proper ancestor of {@code node} that is a word end.
     *
     * @param node the starting node (must not be {@code null})
     * @return that ancestor, or {@code null} if none exists
     */
    public Node getParent(Node node) {
        Node parent = node.parent;
        while (parent != null && !parent.isEnd) {
            parent = parent.parent;
        }
        return parent;
    }

    /**
     * Prunes the trie by clearing {@code isEnd} on every word-end node whose
     * {@code max_org_count} is below {@code min_count}. No nodes are removed. The threshold
     * is remembered and later used by {@link #toParents(Node)}.
     * Call {@link #updateMaxOrgCount()} first.
     *
     * @param min_count the minimum {@code max_org_count} a word end needs to survive
     */
    public void cart(int min_count) {
        this.min_count = min_count;
        if (root != null) {
            cart(root);
        }
    }

    private void cart(Node node) {
        for (Node child : node.childList) {
            if (child.isEnd && child.max_org_count < this.min_count) {
                child.isEnd = false;
            }
            cart(child);
        }
    }
}
程序入口类Main.java:读取词条构建字典树,遍历并排序后将结果写出到文件。
package main;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
/**
 * Program entry point. It builds a trie from a term file, traverses it, sorts the results
 * by count (largest first) and writes them to a file.
 */
public class Main {
    public static void main(String[] args) throws Exception {
        String src = ""; // TODO: input file path
        String des = ""; // TODO: output file path
        int min_count = 100; // pruning threshold; only used if cart() below is enabled
        // Build the tree from column 1 (0-based) of the tab-separated input.
        TrieTree tree = treeInit2(src, 1);
        // max_org_count must be up to date before pruning.
        tree.updateMaxOrgCount();
        // tree.cart(min_count);
        LinkedList<SimpleEntry<Integer, String>> results = tree.level();
        // Sort by key (node count), largest first.
        Collections.sort(results, new Comparator<SimpleEntry<Integer, String>>() {
            @Override
            public int compare(SimpleEntry<Integer, String> o1, SimpleEntry<Integer, String> o2) {
                // Integer.compare avoids the overflow risk of o2 - o1.
                return Integer.compare(o2.getKey(), o1.getKey());
            }
        });
        saveToFile(results, des);
        System.out.println("process sucessful");
    }

    /**
     * Builds a trie from a file that has one term per line. Each non-empty line is inserted
     * reversed; lines are not trimmed.
     *
     * @param src path of the term file (platform default charset)
     * @return the populated tree
     * @throws IOException if the file cannot be read
     */
    public static TrieTree treeInit1(String src) throws IOException {
        TrieTree tree = new TrieTree();
        try (BufferedReader br = openReader(src)) {
            String line;
            while ((line = br.readLine()) != null) {
                insertReversed(tree, line);
            }
        }
        return tree;
    }

    /**
     * Builds a trie from one column of a tab-separated file. The column value is trimmed;
     * non-empty values are inserted reversed. Rows lacking the column are skipped.
     *
     * @param src    path of the tab-separated file (platform default charset)
     * @param colNum 0-based index of the term column
     * @return the populated tree
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException if the file cannot be read
     */
    public static TrieTree treeInit2(String src, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0: " + colNum);
        }
        TrieTree tree = new TrieTree();
        try (BufferedReader br = openReader(src)) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = line.split("\\t");
                // Short rows would otherwise throw ArrayIndexOutOfBoundsException.
                if (colNum < fields.length) {
                    insertReversed(tree, fields[colNum].trim());
                }
            }
        }
        return tree;
    }

    /** Opens {@code src} for buffered reading with the platform default charset. */
    private static BufferedReader openReader(String src) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(src)));
    }

    /** Inserts {@code term} into {@code tree} reversed; empty terms are ignored. */
    private static void insertReversed(TrieTree tree, String term) {
        if (!term.isEmpty()) {
            tree.insert(new StringBuilder(term).reverse().toString());
        }
    }

    /**
     * Writes traversal results to {@code des}, one {@code key<TAB>value} line per entry.
     * The running line number is printed to stdout as progress.
     *
     * @param list entries to write, in order
     * @param des  path of the output file; it is created or overwritten
     * @throws IOException if the file cannot be written
     */
    public static void saveToFile(LinkedList<SimpleEntry<Integer, String>> list, String des) throws IOException {
        try (BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            long counter = 1;
            for (SimpleEntry<Integer, String> entry : list) {
                bw.write(entry.getKey() + "\t" + entry.getValue());
                bw.write(System.lineSeparator());
                System.out.println(counter++);
            }
        }
    }
}