Code structure design
The concrete unit test I ran is com/hankcs/book/ch08/DemoCRFNERPlane.java.
- CRFNERecognizer
  - calls the train method inherited from the parent class CRFTagger (the training parameters are generated by crf_learn.Option())
    - calls the learn method of com/hankcs/hanlp/model/crf/crfpp/Encoder.java, which does the feature extraction and training; see 1-1
      - com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java extends com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java
        - constructor: new EncoderFeatureIndex(threadNum)
        - open(templFile, trainFile)
          - openTemplate(filename1): reads the feature template file
            - makeTempls(unigramTempls_, bigramTempls_) in com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java; unigram templates correspond to state features, bigram templates to transition features
          - openTagSet(filename2): reads the tag set from the training file
      - com/hankcs/hanlp/model/crf/crfpp/TaggerImpl.java
        - constructor: TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN)
        - tagger.open(featureIndex): assigns a few fields on the tagger; every sentence shares the same feature_index object
        - tagger.shrink()
          - feature_index_.buildFeatures(this) --> buildFeatures(TaggerImpl tagger): constructs the features; see 1-2
            - buildFeatureFromTempl(feature, unigramTempls_, cur, tagger)
              - featureID = applyRule(tmpl, curPos, tagger): generates the feature string
              - getID(featureID): looks up the feature's id, assigning a new one if the feature does not exist yet
      - a switch statement picks the training algorithm, either the quasi-Newton L-BFGS algorithm or MIRA; here we focus on the L-BFGS implementation (the gradient is the expected feature counts minus the observed feature counts)
        - gradient computation gradient(expected), run on multiple threads
          - buildLattice(): 1. builds the undirected graph; 2. computes the costs on the nodes and edges
            - rebuildFeatures(): initializes the nodes (Node) and edges (Path) and links them up
            - iterates over the node and edge features, computing calcCost
          - forwardbackward(): the forward-backward algorithm
            - calcAlpha()
            - calcBeta()
            - logsumexp(): working in log space avoids the overflow that taking exp directly would cause
          - calcExpectation: computes the expectation of every feature function
            - iterates over the tokens and edges, accumulating the expected feature counts minus the observed feature counts
        - sums the gradients of all threads
        - updates the objective value according to L1 or L2 regularization
        - passes the objective value and the gradient to lbfgs.optimize(), which updates the weight vector α
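The logsumexp() step above is what keeps calcAlpha()/calcBeta() numerically stable: the sums of exponentials are done in log space, factoring out the maximum exponent so exp() never sees a huge argument. A minimal standalone sketch (a toy, not the crfpp implementation):

```java
public class LogSumExp {
    // Numerically stable log(exp(x) + exp(y)):
    // factor out the larger value so exp() only sees a non-positive argument.
    public static double logSumExp(double x, double y) {
        double max = Math.max(x, y);
        double min = Math.min(x, y);
        return max + Math.log1p(Math.exp(min - max));
    }

    public static void main(String[] args) {
        // The naive form overflows: exp(1000) is already Infinity.
        double naive = Math.log(Math.exp(1000.0) + Math.exp(999.0));
        double stable = logSumExp(1000.0, 999.0);
        System.out.println(naive);   // Infinity
        System.out.println(stable);  // ~1000.3133
    }
}
```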
1 encoder.learn in detail
1-1 Sample processing and feature construction
encoder.learn is the model-training entry point: it constructs features from the input samples and the feature template, and trains with the algorithm selected by the algorithm argument.
featureIndex stores the feature values.
```java
/**
 * Train the model
 *
 * @param templFile     template file
 * @param trainFile     training file
 * @param modelFile     model file
 * @param textModelFile whether to also write the model as a text file
 * @param maxitr        maximum number of iterations
 * @param freq          minimum feature frequency
 * @param eta           convergence threshold
 * @param C             cost factor
 * @param threadNum     number of threads
 * @param shrinkingSize
 * @param algorithm     training algorithm
 * @return
 */
public boolean learn(String templFile, String trainFile, String modelFile, boolean textModelFile,
                     int maxitr, int freq, double eta, double C, int threadNum, int shrinkingSize,
                     Algorithm algorithm)
{
    if (eta <= 0)
    {
        System.err.println("eta must be > 0.0");
        return false;
    }
    if (C < 0.0)
    {
        System.err.println("C must be >= 0.0");
        return false;
    }
    if (shrinkingSize < 1)
    {
        System.err.println("shrinkingSize must be >= 1");
        return false;
    }
    if (threadNum <= 0)
    {
        System.err.println("thread must be > 0");
        return false;
    }
    EncoderFeatureIndex featureIndex = new EncoderFeatureIndex(threadNum); // all features will be stored in featureIndex
    List<TaggerImpl> x = new ArrayList<TaggerImpl>(); // x holds the input samples; for POS tagging, say, each TaggerImpl holds one sentence and x holds all the sentences
    if (!featureIndex.open(templFile, trainFile)) // open the template file and the training file
    {
        System.err.println("Fail to open " + templFile + " " + trainFile);
    }
    BufferedReader br = null;
    try
    {
        // start reading the training samples
        InputStreamReader isr = new InputStreamReader(IOUtil.newInputStream(trainFile), "UTF-8");
        br = new BufferedReader(isr);
        int lineNo = 0;
        // loop over the sentences of the training file
        while (true)
        {
            TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN);
            // assigns a few fields, copying featureIndex and ysize() into the tagger's feature_index_ and ysize_; every sentence shares the same feature_index
            tagger.open(featureIndex);
            TaggerImpl.ReadStatus status = tagger.read(br);
            if (status == TaggerImpl.ReadStatus.ERROR)
            {
                System.err.println("error when reading " + trainFile);
                return false;
            }
            if (!tagger.empty())
            {
                // stores the feature lists (the feature variable in buildFeatures) of the nodes (tokens) and edges (links between adjacent tokens) into feature_cache; since setFeature_id_ is called, each sentence's feature list is easy to retrieve later
                if (!tagger.shrink())
                {
                    System.err.println("fail to build feature index ");
                    return false;
                }
                tagger.setThread_id_(lineNo % threadNum);
                x.add(tagger);
            }
            else if (status == TaggerImpl.ReadStatus.EOF)
            {
                break;
            }
            else
            {
                continue;
            }
            if (++lineNo % 100 == 0)
            {
                System.out.print(lineNo + ".. ");
            }
        }
        br.close();
    }
    catch (IOException e)
    {
        System.err.println("train file " + trainFile + " does not exist.");
        return false;
    }
    featureIndex.shrink(freq, x);
    double[] alpha = new double[featureIndex.size()];
    Arrays.fill(alpha, 0.0);
    featureIndex.setAlpha_(alpha);
    System.out.println("Number of sentences: " + x.size());
    System.out.println("Number of features: " + featureIndex.size());
    System.out.println("Number of thread(s): " + threadNum);
    System.out.println("Freq: " + freq);
    System.out.println("eta: " + eta);
    System.out.println("C: " + C);
    System.out.println("shrinking size: " + shrinkingSize);
    switch (algorithm)
    {
        case CRF_L1:
            if (!runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, true))
            {
                System.err.println("CRF_L1 execute error");
                return false;
            }
            break;
        case CRF_L2:
            if (!runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, false))
            {
                System.err.println("CRF_L2 execute error");
                return false;
            }
            break;
        case MIRA:
            if (!runMIRA(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum))
            {
                System.err.println("MIRA execute error");
                return false;
            }
            break;
        default:
            break;
    }
    if (!featureIndex.save(modelFile, textModelFile))
    {
        System.err.println("Failed to save model");
    }
    System.out.println("Done!");
    return true;
}
```
1-2 The data structures stored in TaggerImpl
```java
Mode mode_ = Mode.TEST;
int vlevel_ = 0;
int nbest_ = 0;
int ysize_;
double cost_;
double Z_;
int feature_id_;
int thread_id_;
FeatureIndex feature_index_;       // holds all features of the training data
List<List<String>> x_;             // one sentence: the outer list holds the rows (tokens), the inner list the columns of each row (a char* in the C++ original)
List<List<Node>> node_;            // effectively a 2-D array: node_[i][j] is the node "token i with label j", e.g. the token "我" with the label "pronoun"
List<Integer> answer_;             // the gold label of each token
List<Integer> result_;
String lastError;
PriorityQueue<QueueElement> agenda_;
List<List<Double>> penalty_;
List<List<Integer>> featureCache_; // cached feature data
```
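To make the shape of node_ concrete: for an n-token sentence with ysize possible labels, buildLattice() ends up with an n × ysize grid of nodes, and every adjacent pair of columns is fully connected by edges. A toy sketch (this Node is a hypothetical stand-in, not the crfpp class):

```java
import java.util.ArrayList;
import java.util.List;

public class LatticeSketch {
    // Stand-in for crfpp's Node, just to show the grid shape:
    // grid.get(i).get(j) = "token i tagged with label j".
    static class Node { int x; int y; }

    public static List<List<Node>> buildGrid(int sentenceLen, int ysize) {
        List<List<Node>> grid = new ArrayList<>();
        for (int i = 0; i < sentenceLen; i++) {
            List<Node> row = new ArrayList<>();
            for (int j = 0; j < ysize; j++) {
                Node n = new Node();
                n.x = i; // token position
                n.y = j; // label index
                row.add(n);
            }
            grid.add(row);
        }
        return grid;
    }

    public static void main(String[] args) {
        // 5 tokens, 4 labels: a 5 x 4 grid of nodes, and rebuildFeatures()
        // would link (5 - 1) * 4 * 4 = 64 edges between adjacent positions.
        List<List<Node>> grid = buildGrid(5, 4);
        System.out.println(grid.size() + " x " + grid.get(0).size()); // 5 x 4
        System.out.println((grid.size() - 1) * 4 * 4);                // 64
    }
}
```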
The buildFeatures method of com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java:
```java
public boolean buildFeatures(TaggerImpl tagger)
{
    List<Integer> feature = new ArrayList<Integer>();
    // featureCache stores the feature list of every node and edge; a node is node[i][j], the edge concept comes up later and can be ignored for now
    List<List<Integer>> featureCache = tagger.getFeatureCache_();
    // remember where this sentence's features start, so they can later be fetched from this id
    tagger.setFeature_id_(featureCache.size());
    // iterate over the tokens, building each token's features
    for (int cur = 0; cur < tagger.size(); cur++)
    {
        if (!buildFeatureFromTempl(feature, unigramTempls_, cur, tagger))
        {
            return false;
        }
        feature.add(-1);
        featureCache.add(feature);
        feature = new ArrayList<Integer>();
    }
    // iterate over the edges, building each edge's features
    for (int cur = 1; cur < tagger.size(); cur++)
    {
        if (!buildFeatureFromTempl(feature, bigramTempls_, cur, tagger))
        {
            return false;
        }
        feature.add(-1);
        featureCache.add(feature);
        feature = new ArrayList<Integer>();
    }
    return true;
}
```
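The cache layout buildFeatures() produces can be illustrated with plain lists: each token contributes one unigram-feature list and each edge one bigram-feature list, every list ends with a -1 sentinel, and the value recorded by setFeature_id_ is simply the cache offset where the sentence starts. A toy sketch with made-up feature ids:

```java
import java.util.ArrayList;
import java.util.List;

public class FeatureCacheSketch {
    // Mimics the featureCache layout built by buildFeatures(): for an n-token
    // sentence, entries [featureId .. featureId + n - 1] hold the node
    // (unigram) feature lists and the next n - 1 entries hold the edge
    // (bigram) feature lists; each list is terminated by a -1 sentinel.
    public static int addSentence(List<List<Integer>> cache,
                                  List<List<Integer>> nodeFeats,
                                  List<List<Integer>> edgeFeats) {
        int featureId = cache.size(); // what setFeature_id_ records
        for (List<Integer> f : nodeFeats) {
            List<Integer> entry = new ArrayList<>(f);
            entry.add(-1);            // sentinel terminating the list
            cache.add(entry);
        }
        for (List<Integer> f : edgeFeats) {
            List<Integer> entry = new ArrayList<>(f);
            entry.add(-1);
            cache.add(entry);
        }
        return featureId;
    }

    public static void main(String[] args) {
        List<List<Integer>> cache = new ArrayList<>();
        // a 3-token sentence with made-up feature ids
        List<List<Integer>> nodes = List.of(List.of(0, 4), List.of(8), List.of(12));
        List<List<Integer>> edges = List.of(List.of(16), List.of(16));
        int id = addSentence(cache, nodes, edges);
        System.out.println(id);           // 0
        System.out.println(cache.size()); // 3 node lists + 2 edge lists = 5
        System.out.println(cache.get(0)); // [0, 4, -1]
    }
}
```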
buildFeatureFromTempl(feature, unigramTempls_, cur, tagger) in com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java. Where do unigramTempls_ and bigramTempls_ come from? They are populated by makeTempls when openTemplate reads the template file.
```java
private boolean buildFeatureFromTempl(List<Integer> feature, List<String> templs, int curPos, TaggerImpl tagger)
{
    for (String tmpl : templs)
    {
        // applyRule generates one feature string from the current position (curPos) and the template (e.g. %x[-2,0]) and stores it in featureID
        String featureID = applyRule(tmpl, curPos, tagger);
        if (featureID == null || featureID.length() == 0)
        {
            System.err.println("format error");
            return false;
        }
        // look up the id of featureID, assigning a new id if the feature does not exist yet; the id is appended to feature, which collects all features built from the current training line
        int id = getID(featureID);
        if (id != -1)
        {
            feature.add(id);
        }
    }
    return true;
}
```
applyRule() of com/hankcs/hanlp/model/crf/crfpp/FeatureIndex.java:
```java
public String applyRule(String str, int cur, TaggerImpl tagger)
{
    StringBuilder sb = new StringBuilder();
    // split the template on "%x"; the leading piece is the template name (U... or B...)
    for (String tmp : str.split("%x", -1))
    {
        if (tmp.startsWith("U") || tmp.startsWith("B"))
        {
            sb.append(tmp);
        }
        else if (tmp.length() > 0)
        {
            // parse "[row,col]" and look up the corresponding token column relative to cur
            String[] tuple = tmp.split("]");
            String[] idx = tuple[0].replace("[", "").split(",");
            String r = getIndex(idx, cur, tagger);
            if (r != null)
            {
                sb.append(r);
            }
            if (tuple.length > 1)
            {
                sb.append(tuple[1]);
            }
        }
    }
    return sb.toString();
}
```
The feature templates predefined in the code:
```
U0:%x[-2,0]
U1:%x[-1,0]
U2:%x[0,0]
U3:%x[1,0]
U4:%x[2,0]
U5:%x[-2,1]
U6:%x[-1,1]
U7:%x[0,1]
U8:%x[1,1]
U9:%x[2,1]
UA:%x[-2,1]%x[-1,1]
UB:%x[-1,1]%x[0,1]
UC:%x[0,1]%x[1,1]
UD:%x[1,1]%x[2,1]
UE:%x[2,1]%x[3,1]
B
```
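Applied to the templates above, applyRule() turns a template like U2:%x[0,0] into a concrete feature string by substituting token columns around the current position. A simplified re-implementation (the boundary marker here is my stand-in; CRF++ proper emits special _B-style tokens for out-of-range rows):

```java
public class TemplateSketch {
    // Simplified applyRule(): expand each %x[row,col] in a template against a
    // token matrix x[position][column], relative to the current position cur.
    public static String apply(String tmpl, int cur, String[][] x) {
        StringBuilder sb = new StringBuilder();
        for (String part : tmpl.split("%x", -1)) {
            if (part.startsWith("U") || part.startsWith("B")) {
                sb.append(part); // template name prefix, e.g. "U2:"
            } else if (!part.isEmpty()) {
                String[] tuple = part.split("]");
                String[] idx = tuple[0].replace("[", "").split(",");
                int row = cur + Integer.parseInt(idx[0].trim()); // relative row
                int col = Integer.parseInt(idx[1].trim());       // column
                // out-of-range rows get a placeholder boundary marker
                sb.append(row >= 0 && row < x.length ? x[row][col] : "_B");
                if (tuple.length > 1) sb.append(tuple[1]);
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // toy 3-token sentence, columns: surface form, tag column
        String[][] x = {{"商", "B"}, {"品", "E"}, {"和", "S"}};
        System.out.println(apply("U2:%x[0,0]", 1, x));         // U2:品
        System.out.println(apply("UB:%x[-1,1]%x[0,1]", 1, x)); // UB:BE
    }
}
```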
getID(String key) of com/hankcs/hanlp/model/crf/crfpp/EncoderFeatureIndex.java:
```java
public int getID(String key)
{
    int k = dic_.get(key);
    if (k == -1) // unseen feature: assign a fresh block of ids
    {
        dic_.put(key, maxid_);
        frequency.append(1);
        int n = maxid_;
        if (key.charAt(0) == 'U')
        {
            // a unigram feature reserves ysize consecutive ids, one weight per label
            maxid_ += y_.size();
        }
        else
        {
            // a bigram feature reserves ysize * ysize ids, one weight per label pair
            bId = n;
            maxid_ += y_.size() * y_.size();
        }
        return n;
    }
    else
    {
        // already seen: just bump its frequency counter
        int cid = continuousId(k);
        int oldVal = frequency.get(cid);
        frequency.set(cid, oldVal + 1);
        return k;
    }
}
```
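The id spacing in getID() is what lets one feature string cover every label: a unigram feature reserves ysize consecutive weight slots and a bigram feature ysize², so the weight for (feature, current label) sits at alpha[id + y] and the weight for (feature, previous label, current label) at alpha[id + y_prev * ysize + y_cur]. A sketch of the arithmetic (the slot-lookup names are mine, not crfpp's):

```java
public class FeatureIdSketch {
    // Reproduces getID()'s id-spacing: a 'U' feature advances maxid_ by ysize,
    // a 'B' feature by ysize * ysize.
    public static int nextId(int maxid, char featureType, int ysize) {
        return featureType == 'U' ? maxid + ysize : maxid + ysize * ysize;
    }

    // Weight slot for (unigram feature fid, current label y).
    public static int unigramSlot(int fid, int y) {
        return fid + y;
    }

    // Weight slot for (bigram feature fid, previous label, current label).
    public static int bigramSlot(int fid, int yPrev, int yCur, int ysize) {
        return fid + yPrev * ysize + yCur;
    }

    public static void main(String[] args) {
        int ysize = 4;                      // e.g. labels B/M/E/S
        int maxid = 0;
        int u0 = maxid;                     // first unigram feature gets id 0
        maxid = nextId(maxid, 'U', ysize);  // reserves ids 0..3 -> maxid 4
        int b0 = maxid;                     // first bigram feature gets id 4
        maxid = nextId(maxid, 'B', ysize);  // reserves ids 4..19 -> maxid 20
        System.out.println(u0 + " " + b0 + " " + maxid); // 0 4 20
        System.out.println(unigramSlot(u0, 2));          // 2
        System.out.println(bigramSlot(b0, 1, 3, ysize)); // 4 + 1*4 + 3 = 11
    }
}
```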