hanlp关于crf源码结构研究

此篇理解参照了C++版CRF++的源码详解

代码结构设计

具体单元测试代码,我是运行的com/hankcs/book/ch08/DemoCRFNERPlane.java这个

  • CRFNERecognizer
    • 「调用继承自父类CRFTagger的train方法」【训练参数在crf_learn.Option()生成】
      • 调用com/hankcs/hanlp/model/crf/crfpp/Encoder.java的learn方法 【进行特征提取和训练】见1-1
        • com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java 继承自com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java
          • 构造函数 new EncoderFeatureIndex(threadNum)
          • open(templFile, trainFile)
          • openTemplate(filename1) 【读取特征模板文件】
          • com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java的makeTempls(unigramTempls_, bigramTempls_) unigram对应的是状态特征,bigram对应的是转移特征
          • openTagSet(filename2) 【读取训练文件中的标注集】
        • com/hankcs/hanlp/model/crf/crfpp/TaggerImpl.java
        • 构造函数 TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN)
        • tagger.open(featureIndex). 给tagger对象做一些属性赋值,所有的句子都对应相同的feature_index对象
        • tagger.shrink()
          • feature_index_.buildFeatures(this) -->buildFeatures(TaggerImpl tagger) 构造特征 见1-2
            • buildFeatureFromTempl(feature, unigramTempls_, cur, tagger)
              • featureID = applyRule(tmpl, curPos, tagger) 生成特征字符串
              • getID(featureID) 获取该特征的id,如果不存在该特征,生成新的id
        • switch函数选择使用拟牛顿算法中的LBFGS算法,还是MIRA算法,进行训练,这里主要看LBFGS算法实现(特征函数的期望减去特征函数真实值)
          • 多线程进行梯度计算gradient(expected)
            • buildLattice() 1. 构建无向图 2. 计算节点以及边上的代价
              • rebuildFeatures() //调用该方法初始化节点(Node)和边(Path),并连接
              • 遍历node和edge特征,计算calcCost
            • forwardbackward() 前向后向算法
              • calcAlpha()
              • calcBeta()
              • logsumexp() 取log的操作是为了防止直接取exp溢出
            • calcExpectation 计算每个特征函数的期望
            • 遍历词和边,计算所有特征函数的期望减去特征函数真实值的和
        • 各线程梯度求和
        • 根据L1或L2正则化,更新似然函数值
        • 传入似然函数值和梯度等参数,调用LBFGS算法lbfgs.optimize() 更新𝛼,𝛽
1 encoder.learn详解
1-1 样本的处理以及特征的构造

encoder.learn 是模型训练代码,根据输入的样本和特征模板构造特征,根据传入的algorithm选择指定的训练算法
featureIndex 存放特征值


```java
    /**
     * 训练
     *
     * @param templFile     模板文件
     * @param trainFile     训练文件
     * @param modelFile     模型文件
     * @param textModelFile 是否输出文本形式的模型文件
     * @param maxitr        最大迭代次数
     * @param freq          特征最低频次
     * @param eta           收敛阈值
     * @param C             cost-factor
     * @param threadNum     线程数
     * @param shrinkingSize
     * @param algorithm     训练算法
     * @return
     */
    public boolean learn(String templFile, String trainFile, String modelFile, boolean textModelFile,
                         int maxitr, int freq, double eta, double C, int threadNum, int shrinkingSize,
                         Algorithm algorithm)
    {
        if (eta <= 0)
        {
            System.err.println("eta must be > 0.0");
            return false;
        }
        if (C < 0.0)
        {
            System.err.println("C must be >= 0.0");
            return false;
        }
        if (shrinkingSize < 1)
        {
            System.err.println("shrinkingSize must be >= 1");
            return false;
        }
        if (threadNum <= 0)
        {
            System.err.println("thread must be  > 0");
            return false;
        }
        EncoderFeatureIndex featureIndex = new EncoderFeatureIndex(threadNum);  //所有的特征将存储在feature_index中
        List<TaggerImpl> x = new ArrayList<TaggerImpl>();  //x存放输入的样本,例如:如果做词性标注的话,TaggerTmpl对象存放的是每句话,而x是所有句子
        if (!featureIndex.open(templFile, trainFile))  //打开“模板文件”和“训练文件”
        {
            System.err.println("Fail to open " + templFile + " " + trainFile);
        }
        BufferedReader br = null;
        try
        {
            //开始读取训练样本
            InputStreamReader isr = new InputStreamReader(IOUtil.newInputStream(trainFile), "UTF-8");
            br = new BufferedReader(isr);
            int lineNo = 0;
            //开始训练样本每行的循环
            while (true) 
            {
                TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN); //
                tagger.open(featureIndex);  //做一些属性赋值,包括featureIndex和ysize()赋值给tagger的feature_index_和ysize_ 。所有的句子都对应相同的feature_index
                TaggerImpl.ReadStatus status = tagger.read(br);
                if (status == TaggerImpl.ReadStatus.ERROR)
                {
                    System.err.println("error when reading " + trainFile);
                    return false;
                }
                if (!tagger.empty())
                {
                    //存储节点(单词)和边(相邻词连接)的特征列表(函数中feature变量),并存储在feature_cache中。调用了set_feature_id方法,因此很容易拿到每个句子对应的特征列表
                    if (!tagger.shrink())  
                    {
                        System.err.println("fail to build feature index ");
                        return false;
                    }
                    tagger.setThread_id_(lineNo % threadNum);
                    x.add(tagger);
                }
                else if (status == TaggerImpl.ReadStatus.EOF)
                {
                    break;
                }
                else
                {
                    continue;
                }
                if (++lineNo % 100 == 0)
                {
                    System.out.print(lineNo + ".. ");
                }
            }
            br.close();
        }
        catch (IOException e)
        {
            System.err.println("train file " + trainFile + " does not exist.");
            return false;
        }
        featureIndex.shrink(freq, x);

        double[] alpha = new double[featureIndex.size()];
        Arrays.fill(alpha, 0.0);
        featureIndex.setAlpha_(alpha);

        System.out.println("Number of sentences: " + x.size());
        System.out.println("Number of features:  " + featureIndex.size());
        System.out.println("Number of thread(s): " + threadNum);
        System.out.println("Freq:                " + freq);
        System.out.println("eta:                 " + eta);
        System.out.println("C:                   " + C);
        System.out.println("shrinking size:      " + shrinkingSize);

        switch (algorithm)
        {
            case CRF_L1:
                if (!runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, true))
                {
                    System.err.println("CRF_L1 execute error");
                    return false;
                }
                break;
            case CRF_L2:
                boolean myflag = runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, false);
                if (!myflag)
                {
                    System.err.println("CRF_L2 execute error");
                    return false;
                }
                break;
            case MIRA:
                if (!runMIRA(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum))
                {
                    System.err.println("MIRA execute error");
                    return false;
                }
                break;
            default:
                break;
        }

        if (!featureIndex.save(modelFile, textModelFile))
        {
            System.err.println("Failed to save model");
        }
        System.out.println("Done!");
        return true;
    }
1-2 TaggerImpl tagger的存储数据的结构
    Mode mode_ = Mode.TEST;
    int vlevel_ = 0;
    int nbest_ = 0;
    int ysize_;
    double cost_;
    double Z_;
    int feature_id_;
    int thread_id_;
    FeatureIndex feature_index_;  //存放训练数据的所有特征
    List<List<String>> x_;  //代表一个句子,外部vector代表多行(多个词),内部vector代表每行的多列,具体的列用char*表示
    List<List<Node>> node_;  //相当于二位数组,node_[i][j]表示一个节点,即:第i个词是第j个label的点。如:“我”这个词是“代词”
    List<Integer> answer_;  //每个词对应的label
    List<Integer> result_;
    String lastError;
    PriorityQueue<QueueElement> agenda_;
    List<List<Double>> penalty_;
    List<List<Integer>> featureCache_;  //缓存特征数据

com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java 的 buildFeatures 方法

    public boolean buildFeatures(TaggerImpl tagger)
    {
        List<Integer> feature = new ArrayList<Integer>();
        //存放是每个节点或者边对应的特征向量,节点便是node[i][j],边的概念后续会接触,暂时可以忽略
        List<List<Integer>> featureCache = tagger.getFeatureCache_();  
        //做个标记,以后要取该句子的特征,可以从该id的位置取
        tagger.setFeature_id_(featureCache.size());

        //遍历每个词,计算每个词的特征
        for (int cur = 0; cur < tagger.size(); cur++)
        {
            if (!buildFeatureFromTempl(feature, unigramTempls_, cur, tagger))
            {
                return false;
            }
            feature.add(-1);
            featureCache.add(feature);
            feature = new ArrayList<Integer>();
        }

        //遍历每条边,计算每条边的特征
        for (int cur = 1; cur < tagger.size(); cur++)
        {
            if (!buildFeatureFromTempl(feature, bigramTempls_, cur, tagger))
            {
                return false;
            }
            feature.add(-1);
            featureCache.add(feature);
            feature = new ArrayList<Integer>();
        }
        return true;
    }

com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java的
buildFeatureFromTempl(feature, unigramTempls_, cur, tagger)
bigramTempls_,unigramTempls_初始设置?

    private boolean buildFeatureFromTempl(List<Integer> feature, List<String> templs, int curPos, TaggerImpl tagger)
    {
        for (String tmpl : templs)
        {
            // applyRule函数根据当前词(cur)以及当前的特征(如: %x[-2,0]),生成一个特征,存放在featureID
            String featureID = applyRule(tmpl, curPos, tagger);
            if (featureID == null || featureID.length() == 0)
            {
                System.err.println("format error");
                return false;
            }
            //获取该featureID的id,如果不存在该特征,生成新的id,将该id添加到feature变量中,feature变量作为当前行训练语料构建的所有特征
            int id = getID(featureID);
            if (id != -1)
            {
                feature.add(id);
            }
        }
        return true;
    }

com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java 的applyRule()

    public String applyRule(String str, int cur, TaggerImpl tagger)
    {
        StringBuilder sb = new StringBuilder();
        for (String tmp : str.split("%x", -1))
        {
            if (tmp.startsWith("U") || tmp.startsWith("B"))
            {
                sb.append(tmp);
            }
            else if (tmp.length() > 0)
            {
                String[] tuple = tmp.split("]");
                String[] idx = tuple[0].replace("[", "").split(",");
                String r = getIndex(idx, cur, tagger);
                if (r != null)
                {
                    sb.append(r);
                }
                if (tuple.length > 1)
                {
                    sb.append(tuple[1]);
                }
            }
        }

        return sb.toString();
    }

代码中预定义好的特征模板

U0:%x[-2,0]
U1:%x[-1,0]
U2:%x[0,0]
U3:%x[1,0]
U4:%x[2,0]
U5:%x[-2,1]
U6:%x[-1,1]
U7:%x[0,1]
U8:%x[1,1]
U9:%x[2,1]
UA:%x[-2,1]%x[-1,1]
UB:%x[-1,1]%x[0,1]
UC:%x[0,1]%x[1,1]
UD:%x[1,1]%x[2,1]
UE:%x[2,1]%x[3,1]
B

com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java的 getID(String key)

    public int getID(String key)
    {
        int k = dic_.get(key);
        if (k == -1)
        {
            dic_.put(key, maxid_);
            frequency.append(1);
            int n = maxid_;
            if (key.charAt(0) == 'U')
            {
                maxid_ += y_.size();
            }
            else
            {
                bId = n;
                maxid_ += y_.size() * y_.size();
            }
            return n;
        }
        else
        {
            int cid = continuousId(k);
            int oldVal = frequency.get(cid);
            frequency.set(cid, oldVal + 1);
            return k;
        }
    }
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
HanLP是一系列模型算法组成的NLP工具包,由大快搜索主导并完全开源,目标是普及自然语言处理在生产环境中的应用.HanLP具备功能完善,性能高效,架构清晰,语料时新,可自定义的特点。 HanLP提供下列功能: 中文分词 HMM-两字组(速度与精度最佳平衡;一百兆内存) 最短路分词,N-最短路分词 由字构词(侧重精度,全世界最大语料库,可识别新词;适合NLP任务) 感知机分词,CRF分词 词典分词(侧重速度,每秒数千万字符;省内存) 极速词典分词 所有分词器都支持: 索引全切分模式 用户自定义词典 兼容繁体中文 训练用户自己的领域模型 词性标注 HMM词性标注(速度快) 感知机词性标注,CRF词性标注(精度高) 命名实体识别 基于HMM角色标注的命名实体识别(速度快) 中国人名识别,音译人名识别,日本人名识别,地名识别,实体机构名识别 基于线性模型的命名实体识别(精度高) 感知机命名实体识别,CRF命名实体识别 关键词提取 TextRank关键词提取 自动摘要 TextRank自动摘要 短语提取 基于互信息和左右信息熵的短语提取 拼音转换 多音字,声母,韵母,声调 简繁转换 简繁分歧词(简体,繁体,台湾正体,香港繁体) 文本推荐 语义推荐,拼音推荐,字词推荐 依存句法分析 基于神经网络的高性能依存句法分析器 基于ArcEager转移系统的柱搜索依存句法分析器 文本分类 情感分析 文本聚类 KMeans,Repeated Bisection,自动推断聚类数目k word2vec 词向量训练,加载,词语相似度计算,语义运算,查询,KMEANS聚类 文档语义相似度计算 语料库工具 部分默认模型训练自小型语料库,鼓励用户自行训练。模块所有提供训练接口,语料可参考98年人民日报语料库。 在提供丰富功能的同时,HanLP内部模块坚持低耦合,模型坚持惰性加载,服务坚持静态提供,词典坚持明文发布,使用非常方便。默认模型训练自全世界最大规模的中文语料库,同时自带一些语料处理工具,帮助用户训练自己的模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值