字典树的构建

摘要

  本文主要讲述基于Java语言构建字典树,包括字典树的剪枝与遍历操作。字典树的原理不再赘述,代码实现如下。

实现部分

工具类Tools.java,主要实现对大数据集的采样、指定列的提取,以及对数据规模的统计

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashSet;

/**
 * File helpers for preparing trie input data: sampling the head of a large
 * file, extracting one tab-separated column, and reporting basic size
 * statistics.
 *
 * <p>All readers and writers use the platform default charset, exactly as the
 * underlying {@code InputStreamReader}/{@code OutputStreamWriter} constructors do.
 */
public class Tools {

    /**
     * Copies at most the first {@code size} lines of {@code src} into {@code des}.
     *
     * <p>If {@code src} holds fewer than {@code size} lines, copying stops at end
     * of file instead of failing. The number of lines actually written is
     * reported on standard output.
     *
     * @param src  path of the file to read
     * @param des  path of the file to (over)write
     * @param size maximum number of lines to copy
     * @throws IOException if either file cannot be opened, read or written
     */
    public static void getSample(String src, String des, long size) throws IOException {
        long written = 0;   // long counter: an int would wrap for size > Integer.MAX_VALUE

        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            // Stop at end of file as well, so a short source no longer causes write(null).
            while (written < size && (line = br.readLine()) != null) {
                bw.write(line);
                bw.write(System.lineSeparator());
                written++;
            }
        }

        System.out.println("get sample successful, size is: " + written);
    }

    /**
     * Extracts one tab-separated column (0-based) from every line of {@code src}
     * and writes the trimmed values to {@code des}, one per line.
     *
     * <p>Lines that do not contain the requested column, or whose value is empty
     * after trimming, are skipped.
     *
     * @param src    path of the tab-separated input file
     * @param des    path of the file to (over)write
     * @param colNum 0-based index of the column to extract
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException              if either file cannot be opened, read or written
     */
    public static void getColumn(String src, String des, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0, but was " + colNum);
        }

        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
             BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = line.split("\\t");
                // split() drops trailing empty fields, so short rows are normal: skip them.
                if (colNum >= fields.length) {
                    continue;
                }

                String column = fields[colNum].trim();
                if (!column.isEmpty()) {            // ignore blank records
                    bw.write(column);
                    bw.write(System.lineSeparator());   // one record per line
                }
            }
        }

        System.out.println("extract column to file successful, cloumnNum is: " + colNum);
    }

    /**
     * Prints the number of distinct characters and the number of distinct lines
     * found in {@code src}.
     *
     * <p>Characters are counted per Unicode code point; blank lines contribute a
     * (single) distinct line but no character.
     *
     * @param src path of the file to analyse
     * @throws IOException if the file cannot be opened or read
     */
    public static void getSize(String src) throws IOException {
        HashSet<String> chars_set = new HashSet<String>();      // distinct characters
        HashSet<String> terms_set = new HashSet<>();            // distinct lines

        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                terms_set.add(line);

                // Walk by code point so surrogate pairs stay together and an
                // empty line does not add "" to the character set.
                for (int i = 0; i < line.length(); ) {
                    int codePoint = line.codePointAt(i);
                    chars_set.add(new String(Character.toChars(codePoint)));
                    i += Character.charCount(codePoint);
                }
            }
        }

        System.out.println("chars_set.size():" + chars_set.size());
        System.out.println("terms_set.size():" + terms_set.size());
    }

    /**
     * Example driver: extracts column index 1 of {@code path1} into {@code path2}
     * and prints the size statistics of the result. Both paths are placeholders
     * that must be filled in before running.
     */
    public static void main(String[] args) throws IOException {
        String path1 = "";
        String path2 = "";

        // Extract column index 1 (columns are counted from 0).
        getColumn(path1, path2, 1);
        getSize(path2);
    }
}

节点类Node.java,树节点

package main;

import java.util.LinkedList;

/**
 * 树节点类
 * */

public class Node {
    char val;       // 保存当前节点的字符
    int count;      // 统计经过当前节点的字符串数目
    boolean isEnd;  // 标志当前节点是否是一个词的末尾

    Node parent;    // 存储当前节点的父节点
    LinkedList<Node> childList; // 存储当前节点的直接子节点

    int org_count;  // 当前节点的原始频次
    int max_org_count;  // 当前节点路径上的最大org_count

    /**
     * 带参构造方法,构造含有指定字符val的节点,并指明新建节点的父节点
     * */
    public Node(char val, Node parent) {
        this.val = val;
        this.count = 1;
        this.isEnd = false;

        this.parent = parent;
        this.childList = new LinkedList<Node>();

        this.org_count = 0;
        this.max_org_count = 0;
    }

    /**
     * 无参构造方法,初始化val=' ', count=0, isEnd=false, parent=null, childList=new List()
     * */
    public Node() {
        this(' ', null);
    }


    /**
     * 根据指定的字符,获取当前节点的子节点。子节点不存在则返回null
     * */
    public Node getNode(char val) {
        // 依次遍历链表
        for (Node child : childList) {
            if (child.val == val)
                return child;
            else
                continue;
        }

        return null;
    }

    @Override
    public String toString() {
        return this.val + ":" + this.count;
    }
}

字典树类TrieTree.java

package main;

import java.util.AbstractMap.SimpleEntry;
import java.util.LinkedList;
import java.util.Map.Entry;

/**
 * A trie built from {@link Node} objects.
 *
 * <p>Paths are rendered from a node up towards the root (see
 * {@link #toRoot(Node)}), i.e. in the reverse of insertion order. Inserting
 * reversed words therefore yields the original words on output and groups
 * words by common suffix:
 *
 * <pre>
 *   TrieTree tree = new TrieTree();
 *   tree.insert(new StringBuilder("word").reverse().toString());
 *   tree.updateMaxOrgCount();
 *   LinkedList&lt;SimpleEntry&lt;Integer, String&gt;&gt; paths = tree.level();
 * </pre>
 *
 * @author stevinpan
 */

public class TrieTree {
    public Node root = null;        // root of the tree; does not represent any character
    private int min_count = 0;      // pruning threshold, set by cart(int)

    /**
     * Creates an empty tree consisting of a root node only.
     */
    public TrieTree() {
        root = new Node();
    }

    /**
     * Tells whether {@code word} was inserted into the tree (and is still marked
     * as a word end).
     *
     * @param word string to look up
     * @return true if following {@code word} from the root ends on a node whose
     *         {@code isEnd} flag is set; false otherwise, including for a null
     *         word or a null root
     */
    public boolean isExist(String word) {
        Node curr = root;

        if (curr == null || word == null) {
            return false;
        }

        // Follow the word one character at a time; a missing child means "not present".
        for (int index = 0; index < word.length(); index++) {
            curr = curr.getNode(word.charAt(index));
            if (curr == null) {
                return false;
            }
        }

        return curr.isEnd;
    }

    /**
     * Inserts {@code word}; inserting the same string again is allowed and
     * raises the counters along its path.
     *
     * <p>Every already existing node on the path gets {@code count++}; missing
     * nodes are created (they start with count 1). The last node is flagged as
     * a word end and its {@code org_count} is incremented.
     *
     * @param word string to insert
     * @return true on success, false when {@code word} is null
     */
    public boolean insert(String word) {
        if (word == null) {
            return false;
        }

        if (root == null) {
            root = new Node();
        }

        Node curr = root;

        for (int index = 0; index < word.length(); index++) {
            char ch = word.charAt(index);
            Node next = curr.getNode(ch);

            if (next != null) {
                next.count++;
            } else {
                next = new Node(ch, curr);
                curr.childList.add(next);
            }
            curr = next;
        }

        // Mark the end of the word and record one more original occurrence.
        curr.isEnd = true;
        curr.org_count++;

        return true;
    }

    /**
     * Prints every word-end node below the root to standard output, one per
     * line, as {@code count<TAB>path}, visiting nodes depth-first.
     */
    public void printAll() {
        printAll(root);
    }

    /** Depth-first helper of {@link #printAll()}; visits the descendants of {@code node}. */
    private void printAll(Node node) {
        if (node == null) {
            return;
        }
        for (Node child : node.childList) {
            if (child.isEnd) {
                System.out.println(child.count + "\t" + toRoot(child));
            }
            printAll(child);
        }
    }

    /**
     * Collects one entry per word-end node, visiting nodes depth-first from the
     * root.
     *
     * <p>The entry key is the node's {@code count}; the value is
     * {@code org_count<TAB>path} followed by the tab-prefixed paths of the
     * qualifying word-end ancestors (see {@link #toParents(Node)}).
     *
     * @return the collected entries in visiting order
     */
    public LinkedList<SimpleEntry<Integer, String>> level() {
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<SimpleEntry<Integer, String>>();
        level(root, result);

        return result;
    }

    /** Depth-first helper of {@link #level()}; handles {@code node} and then its descendants. */
    private void level(Node node, LinkedList<SimpleEntry<Integer, String>> results) {
        if (node == null) {
            return;
        }

        if (node.isEnd) {
            String full_path = toRoot(node);                    // path of this node
            String parent_paths = toParents(node.parent);       // paths of word-end ancestors

            results.add(new SimpleEntry<Integer, String>(node.count, node.org_count + "\t" + full_path + parent_paths));
        }

        for (Node child : node.childList) {
            level(child, results);
        }
    }

    /**
     * Builds the string of characters found when walking from {@code node} up
     * to, but not including, the root.
     *
     * @param node start node; null or the root yields the empty string
     * @return the characters from {@code node} upwards
     */
    public String toRoot(Node node) {
        StringBuilder path = new StringBuilder();
        // The null check makes the walk terminate for nodes outside this tree too.
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            path.append(curr.val);
        }
        return path.toString();
    }

    /**
     * Walks from {@code node} up to the root and concatenates, each prefixed by
     * a tab, the paths of all word-end nodes whose {@code org_count} is greater
     * than the current pruning threshold.
     *
     * <p>NOTE(review): {@code cart} keeps word ends with
     * {@code max_org_count >= min_count} while this filter uses
     * {@code org_count > min_count}; confirm the difference is intended.
     *
     * @param node first node to inspect; null (e.g. the root's parent) yields
     *             the empty string
     * @return tab-prefixed paths, or the empty string if none qualifies
     */
    public String toParents(Node node) {
        StringBuilder res = new StringBuilder();
        // Null guard: the root's parent is null, which used to crash here when
        // the root itself was a word end (after inserting the empty string).
        for (Node curr = node; curr != null && curr != root; curr = curr.parent) {
            if (curr.isEnd && curr.org_count > this.min_count) {
                res.append("\t").append(toRoot(curr));
            }
        }

        return res.toString();
    }

    /**
     * Collects all nodes that have no children.
     *
     * @return the leaf nodes in depth-first order (just the root for an empty tree)
     */
    public LinkedList<Node> findLeaf() {
        LinkedList<Node> leafList = new LinkedList<>();
        findLeaf(root, leafList);

        return leafList;
    }

    /** Depth-first helper of {@link #findLeaf()}. */
    private void findLeaf(Node node, LinkedList<Node> leafList) {
        if (node == null) {
            return;
        }
        if (node.childList.isEmpty()) {
            leafList.add(node);
            return;
        }
        for (Node child : node.childList) {
            findLeaf(child, leafList);
        }
    }

    /**
     * Walks from every leaf up to the root and records the paths of all
     * word-end nodes met on the way.
     *
     * <p>Each entry has the leaf's {@code count} as key and
     * {@code <TAB>count} followed by one {@code <TAB>path} per word-end node
     * as value.
     *
     * @return one entry per leaf
     */
    public LinkedList<SimpleEntry<Integer, String>> back() {
        LinkedList<Node> leafList = findLeaf();
        LinkedList<SimpleEntry<Integer, String>> result = new LinkedList<>();

        for (Node leaf : leafList) {
            int count = leaf.count;
            StringBuilder sb = new StringBuilder();

            sb.append("\t").append(count);
            for (Node curr = leaf; curr != null && curr != root; curr = curr.parent) {
                if (curr.isEnd) {
                    sb.append("\t").append(toRoot(curr));
                }
            }

            // sb always holds at least the count, so every leaf produces an entry.
            result.add(new SimpleEntry<Integer, String>(count, sb.toString()));
        }

        return result;
    }

    /**
     * Recomputes {@code max_org_count} of every word-end node, top-down: it
     * becomes the larger of the node's own {@code org_count} and the
     * {@code max_org_count} of its nearest word-end ancestor. Call this before
     * {@link #cart(int)}.
     */
    public void updateMaxOrgCount() {
        updateMaxOrgCount(root);
    }

    /** Pre-order helper of {@link #updateMaxOrgCount()}: ancestors are updated before descendants. */
    private void updateMaxOrgCount(Node node) {
        if (node == null) {
            return;
        }

        if (node.isEnd) {
            Node parent_isEnd = getParent(node);    // nearest word-end ancestor, if any
            node.max_org_count = (parent_isEnd == null)
                    ? node.org_count
                    : Math.max(node.org_count, parent_isEnd.max_org_count);
        }

        for (Node child : node.childList) {
            updateMaxOrgCount(child);
        }
    }

    /**
     * Finds the nearest ancestor of {@code node} that is a word end.
     *
     * @param node node whose ancestors are searched
     * @return that ancestor, or null if no ancestor is a word end
     */
    public Node getParent(Node node) {
        Node parent = node.parent;

        while (parent != null && !parent.isEnd) {
            parent = parent.parent;
        }

        return parent;
    }

    /**
     * Prunes the tree: every word-end node whose {@code max_org_count} is below
     * {@code min_count} loses its {@code isEnd} flag (nodes themselves are not
     * removed). The threshold is remembered and also used by
     * {@link #toParents(Node)}.
     *
     * @param min_count pruning threshold
     */
    public void cart(int min_count) {
        this.min_count = min_count;
        cart(root);
    }

    /** Depth-first helper of {@link #cart(int)}; processes the descendants of {@code node}. */
    private void cart(Node node) {
        if (node == null) {
            return;
        }
        for (Node child : node.childList) {
            if (child.isEnd && child.max_org_count < this.min_count) {
                child.isEnd = false;
            }
            cart(child);
        }
    }

}

主要操作类Main.java

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;


/**
 * Program entry point: builds a trie from a word file, updates the per-node
 * frequency maxima, traverses the tree and writes the sorted result to a file.
 */
public class Main {
    /**
     * Runs the whole pipeline. {@code src} and {@code des} are placeholders that
     * must be set to real paths before running.
     */
    public static void main(String[] args) throws Exception {
        String src = "";
        String des = "";
        int min_count = 100;            // pruning threshold, used by the optional cart() call below

        // Build the tree from column index 1 (columns are counted from 0) of the input file.
        TrieTree tree = treeInit2(src, 1);

        // max_org_count must be up to date before the tree is pruned.
        tree.updateMaxOrgCount();
//      tree.cart(min_count);

        // Traverse, then sort by key in descending order.
        LinkedList<SimpleEntry<Integer, String>> results = tree.level();
        Collections.sort(results, new Comparator<SimpleEntry<Integer, String>>() {
            @Override
            public int compare(SimpleEntry<Integer, String> o1, SimpleEntry<Integer, String> o2) {
                // Integer.compare avoids the overflow risk of subtracting the keys.
                return Integer.compare(o2.getKey(), o1.getKey());
            }
        });

        saveToFile(results, des);
        System.out.println("process sucessful");
    }

    /**
     * Builds a tree from a file holding one term per line. Each non-empty line
     * is reversed and inserted.
     *
     * @param src term file, one term per line
     * @return the populated tree
     * @throws IOException if the file cannot be opened or read
     */
    public static TrieTree treeInit1(String src) throws IOException {
        TrieTree tree = new TrieTree();

        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (!line.equals("")) {     // skip empty terms
                    tree.insert(new StringBuilder(line).reverse().toString());
                }
            }
        }

        return tree;
    }

    /**
     * Builds a tree from one column of a tab-separated file. The trimmed value
     * of column {@code colNum} of each line is reversed and inserted; lines
     * without that column or with an empty value are skipped.
     *
     * @param src    tab-separated input file
     * @param colNum 0-based index of the column holding the terms
     * @return the populated tree
     * @throws IllegalArgumentException if {@code colNum} is negative
     * @throws IOException              if the file cannot be opened or read
     */
    public static TrieTree treeInit2(String src, int colNum) throws IOException {
        if (colNum < 0) {
            throw new IllegalArgumentException("colNum must be >= 0, but was " + colNum);
        }

        TrieTree tree = new TrieTree();

        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = line.split("\\t");
                // split() drops trailing empty fields, so a row may be shorter than expected.
                if (colNum >= fields.length) {
                    continue;
                }

                String term = fields[colNum].trim();
                if (!term.equals("")) {     // skip empty terms
                    tree.insert(new StringBuilder(term).reverse().toString());
                }
            }
        }

        return tree;
    }

    /**
     * Writes the traversal result to a file, one {@code key<TAB>value} line per
     * entry, and prints a running entry number to standard output.
     *
     * @param list entries to write
     * @param des  path of the file to (over)write
     * @author stevinpan
     * @throws IOException if the file cannot be opened or written
     */
    public static void saveToFile(LinkedList<SimpleEntry<Integer, String>> list, String des) throws IOException {
        try (BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(des)))) {
            long counter = 1;
            for (SimpleEntry<Integer, String> entry : list) {
                String kv = entry.getKey() + "\t" + entry.getValue();
                bw.write(kv);
                bw.write(System.lineSeparator());

                System.out.println(counter++);
            }
        }
    }
}

工程文件地址为:https://github.com/panshan/TrieTree.git

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值