代码结构设计
具体单元测试代码,我是运行的com/hankcs/book/ch08/DemoCRFNERPlane.java这个
- CRFNERecognizer
- 「调用继承自父类CRFTagger的train方法」【训练参数在crf_learn.Option()生成】
- 调用com/hankcs/hanlp/model/crf/crfpp/Encoder.java的learn方法 【进行特征提取和训练】见1-1
- com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java 继承自com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java
- 构造函数 new EncoderFeatureIndex(threadNum)
- open(templFile, trainFile)
- openTemplate(filename1) 【读取特征模板文件】
- com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java的makeTempls(unigramTempls_, bigramTempls_) unigram对应的是状态特征,bigram对应的是转移特征
- openTagSet(filename2) 【读取训练文件中的标注集】
- com/hankcs/hanlp/model/crf/crfpp/TaggerImpl.java
- 构造函数 TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN)
- tagger.open(featureIndex). 给tagger对象做一些属性赋值,所有的句子都对应相同的feature_index对象
- tagger.shrink()
- feature_index_.buildFeatures(this) -->buildFeatures(TaggerImpl tagger) 构造特征 见1-2
- buildFeatureFromTempl(feature, unigramTempls_, cur, tagger)
- featureID = applyRule(tmpl, curPos, tagger) 生成特征字符串
- getID(featureID) 获取该特征的id,如果不存在该特征,生成新的id
- buildFeatureFromTempl(feature, unigramTempls_, cur, tagger)
- feature_index_.buildFeatures(this) -->buildFeatures(TaggerImpl tagger) 构造特征 见1-2
- switch函数选择使用拟牛顿算法中的LBFGS算法,还是MIRA算法,进行训练,这里主要看LBFGS算法实现(特征函数的期望减去特征函数真实值)
- 多线程进行梯度计算gradient(expected)
- buildLattice() 1. 构建无向图 2. 计算节点以及边上的代价
- rebuildFeatures() //调用该方法初始化节点(Node)和边(Path),并连接
- 遍历node和edge特征,计算calcCost
- forwardbackward() 前向后向算法
- calcAlpha()
- calcBeta()
- logsumexp() 取log的操作是为了防止直接取exp溢出
- calcExpectation 计算每个特征函数的期望
- 遍历词和边,计算所有特征函数的期望减去特征函数真实值的和
- buildLattice() 1. 构建无向图 2. 计算节点以及边上的代价
- 多线程进行梯度计算gradient(expected)
- 各线程梯度求和
- 根据L1或L2正则化,更新似然函数值
- 传入似然函数值和梯度等参数,调用LBFGS算法lbfgs.optimize() 更新𝛼,𝛽
- com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java 继承自com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java
- 调用com/hankcs/hanlp/model/crf/crfpp/Encoder.java的learn方法 【进行特征提取和训练】见1-1
- 「调用继承自父类CRFTagger的train方法」【训练参数在crf_learn.Option()生成】
1 encoder.learn详解
1-1 样本的处理以及特征的构造
encoder.learn 是模型训练代码,根据输入的样本和特征模板构造特征,根据传入的algorithm选择指定的训练算法
featureIndex 存放特征值
```java
/**
* 训练
*
* @param templFile 模板文件
* @param trainFile 训练文件
* @param modelFile 模型文件
* @param textModelFile 是否输出文本形式的模型文件
* @param maxitr 最大迭代次数
* @param freq 特征最低频次
* @param eta 收敛阈值
* @param C cost-factor
* @param threadNum 线程数
* @param shrinkingSize
* @param algorithm 训练算法
* @return
*/
public boolean learn(String templFile, String trainFile, String modelFile, boolean textModelFile,
int maxitr, int freq, double eta, double C, int threadNum, int shrinkingSize,
Algorithm algorithm)
{
if (eta <= 0)
{
System.err.println("eta must be > 0.0");
return false;
}
if (C < 0.0)
{
System.err.println("C must be >= 0.0");
return false;
}
if (shrinkingSize < 1)
{
System.err.println("shrinkingSize must be >= 1");
return false;
}
if (threadNum <= 0)
{
System.err.println("thread must be > 0");
return false;
}
EncoderFeatureIndex featureIndex = new EncoderFeatureIndex(threadNum); //所有的特征将存储在feature_index中
List<TaggerImpl> x = new ArrayList<TaggerImpl>(); //x存放输入的样本,例如:如果做词性标注的话,TaggerTmpl对象存放的是每句话,而x是所有句子
if (!featureIndex.open(templFile, trainFile)) //打开“模板文件”和“训练文件”
{
System.err.println("Fail to open " + templFile + " " + trainFile);
}
BufferedReader br = null;
try
{
//开始读取训练样本
InputStreamReader isr = new InputStreamReader(IOUtil.newInputStream(trainFile), "UTF-8");
br = new BufferedReader(isr);
int lineNo = 0;
//开始训练样本每行的循环
while (true)
{
TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN); //
tagger.open(featureIndex); //做一些属性赋值,包括featureIndex和ysize()赋值给tagger的feature_index_和ysize_ 。所有的句子都对应相同的feature_index
TaggerImpl.ReadStatus status = tagger.read(br);
if (status == TaggerImpl.ReadStatus.ERROR)
{
System.err.println("error when reading " + trainFile);
return false;
}
if (!tagger.empty())
{
//存储节点(单词)和边(相邻词连接)的特征列表(函数中feature变量),并存储在feature_cache中。调用了set_feature_id方法,因此很容易拿到每个句子对应的特征列表
if (!tagger.shrink())
{
System.err.println("fail to build feature index ");
return false;
}
tagger.setThread_id_(lineNo % threadNum);
x.add(tagger);
}
else if (status == TaggerImpl.ReadStatus.EOF)
{
break;
}
else
{
continue;
}
if (++lineNo % 100 == 0)
{
System.out.print(lineNo + ".. ");
}
}
br.close();
}
catch (IOException e)
{
System.err.println("train file " + trainFile + " does not exist.");
return false;
}
featureIndex.shrink(freq, x);
double[] alpha = new double[featureIndex.size()];
Arrays.fill(alpha, 0.0);
featureIndex.setAlpha_(alpha);
System.out.println("Number of sentences: " + x.size());
System.out.println("Number of features: " + featureIndex.size());
S