HMM的原理就不说了,这里主要说算法的实现。
实际实现起来并不是很困难,前提是你仔细看过hmm的原理,然后很多实现就照着公式写出对应的代码,比如前向算法,后向算法,参数更新都是有明确的公式的,只需要对应写成代码,这里需要提到2点技巧。
1,所有概率需要取对数,这是因为有的概率实在是太小了,容易溢出,或者精度不够。
2.对一个求和的式子取对数概率时需要用到一个技巧。下面直接贴出我写的关于这个计算技巧的理解。
LogSum计算技巧
如果需要计算下面的式子:
其中α是一个概率,只知道 而不知道 ,此时如果直接计算会溢出,为了解决这个问题就可以用到这个logSum计算技巧。传入参数是一个数组,每个元素为 ,找到其中的最大值
,根据:
示例代码:
这是一个java版本实现的无监督HMM,包括学习算法和预测算法,算法没有错误,我已经做过多次测试,但是由于HMM的训练算法就是EM算法,而EM算法对初值十分敏感,所以训练时必须给定一些先验条件,即需要给定HMM中的参数pi,A,B至少其中一个,不然训练出来的参数将时一样的,毫无意义。当然无监督的HMM效果依然不如监督学习的HMM,我测试了一下分词,给定了一个监督学习HMM分词的参数来训练无监督的HMM,效果如下:
无监督HMM效果:给定参数pi和B
参数已收敛....
最终参数:
pi:[0.0, -2.1474836360090876E9, -2.147483633470365E9, -2.1474836334854264E9]
A:
[-2.1474836482889004E9, -2.2141536347013195, -0.115686913295864, -2.147483648337081E9]
[-2.1474836479972153E9, -1.0239098431251064, -0.4450188807874582, -2.1474836480612097E9]
[-0.7149451174254677, -2.1474836483350754E9, -2.147483648333808E9, -0.6718142750423626]
[-0.4322715044602754, -2.1474836481090927E9, -2.1474836481949987E9, -1.0470634684331648]
[原标题, :, 日, 媒拍, 到, 了, 现场, 罕见, 一幕, 据, 日本, 新闻, 网, (, N, NN, ), 9月, 8日, 报道, ,, 日前, ,, 日本, 海上, 自卫队, 现役, 最大, 战舰, 之一, 的, 直升, 机航, 母, “, 加贺, ”, 号在, 南, 海航, 行时, ,, 遭多, 艘, 中国, 海军, 战舰, 抵, 近跟, 踪, 监视, 。]
监督学习HMM:
[原, 标题, :, 日媒, 拍到, 了, 现场, 罕见, 一幕, 据, 日本, 新闻网, (, NN, N)9月8日, 报道, ,, 日前, ,, 日本, 海上, 自卫队, 现役, 最大, 战舰, 之, 一, 的, 直升, 机航母, “, 加贺, ”, 号, 在, 南海, 航行, 时, ,, 遭多, 艘, 中国, 海军, 战舰, 抵近, 跟踪, 监视, 。]
虽然没有指定参数A,但是可以看到学习出来的A还是有准确性,比如B转移到B的概率为0,B转移到S的概率为0,M转移到M的概率为0,M转移到S的概率为0.....这和监督学习的HMM一样的。
这样看来这个算法确实有效。
光从结果来看无监督的HMM指定了pi和B参数,整体效果还是差于监督学习的HMM。测试语料只有人民日报1998的分割语料。
由于原代码比较长,并且本来是写到我的开源项目中的,所以不简单是整合到一个类中就是所有代码,还包含了一些依赖。
这里我整理出了串行版本的只包好2个依赖的代码,供学习使用,由于训练中很多步骤都可以并行实现,所以我并行了一些消耗时间的步骤,要比串行的快得多。过几天我会更新到github上,完整的源码请参考我的开源项目:GitHub - colin0000007/CONLP: 一个自然语言处理初学者可以参考的库,包含分词,词性标注,命名实体识别,依存句法分析大多模型和算法都是自己实现 。a natural language processing library for beginners
代码中需要用到语料以及HMM的参数A和B:
语料以及参数A和B.rar_免费高速下载|百度网盘-分享无限制
下面是串行版本的代码,:
package com.outsider.test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
/**
*
* 无监督学习的HMM实现
* 少量数据建议串行
* 大量数据,几十万,百万甚至更高的数据强烈建议并行训练,性能是串行的好4倍以上
* @author outsider
*/
public class UnsupervisedFirstOrderGeneralHMM{
private double precision = 1e-7;
/**
* 训练数据长度
*/
private int sequenceLen;
public Logger logger = Logger.getLogger(UnsupervisedFirstOrderGeneralHMM.class.getName());
/**初始状态概率**/
protected double[] pi;
/**转移概率**/
protected double[][] transferProbability1;
/**发射概率**/
protected double[][] emissionProbability;
/**定义无穷大**/
public static final double INFINITY = (double) -Math.pow(2, 31);
/**状态值集合的大小**/
protected int stateNum;
/**观测值集合的大小**/
protected int observationNum;
public UnsupervisedFirstOrderGeneralHMM() {
super();
}
public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum, double[] pi,
double[][] transferProbability1, double[][] emissionProbability) {
this.stateNum = stateNum;
this.observationNum = observationNum;
this.pi = pi;
this.transferProbability1 = transferProbability1;
this.emissionProbability = emissionProbability;
}
public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum) {
this.stateNum = stateNum;
this.observationNum = observationNum;
initParameters();
}
/**
* λ是HMM参数的总称
*/
/**
* 训练方法
* @param x 训练序列数据
* @param maxIter 最大迭代次数
* @param precision 精度
*/
public void train(int[] x, int maxIter, double precision) {
this.sequenceLen = x.length;
baumWelch(x, maxIter, precision);
}
public void train(int[] x) {
this.sequenceLen = x.length;
//不做概率归一化
}
/**
* baumWelch算法迭代求解
* 迭代时存在这样的现象:新参数和上一次的参数差反而会变大,但是到后面这个误差值几乎会收敛
* 所以迭代终止的条件有2个:
* 1.达到最大迭代次数
* 2.参数A,B,pi中的值相比上一次的最大误差小于某个精度值则认为收敛
* 3.若1中给的精度值太大,则可能导致无法收敛,所以增加了一个条件,如果当前迭代的误差和上一次迭代的误差小于某个值(这里给定1e-7),
* 可以认为收敛了。
* @param x 观测序列
* @param maxIter 最大迭代次数,如果传入<=0的数则默认为Integer.MAX_VALUE,相当于不收敛就不跳出循环
* @param precision 参数误差的精度小于precision就认为收敛
*/
protected void baumWelch(int[] x, int maxIter, double precision) {
int iter = 0;
double oldMaxError = 0;
if(maxIter <= 0) {
maxIter = Integer.MAX_VALUE;
}
//初始化各种参数
double[][] alpha = new double[sequenceLen][stateNum];
double[][] beta = new double[sequenceLen][stateNum];
double[][] gamma = new double[sequenceLen][stateNum];
double[][][] ksi = new double[sequenceLen][stateNum][stateNum];
while(iter < maxIter) {
logger.info("\niter"+iter+"...");
long start = System.currentTimeMillis();
//计算各种参数,为更新模型参数做准备,对应EM中的E步
calcAlpha(x, alpha);
calcBeta(x, beta);
calcGamma(x, alpha, beta, gamma);
calcKsi(x, alpha, beta, ksi);
//更新参数,对应EM中的M步
double[][] oldA = generateOldA();
//double[][] oldB = generateOldB();
//double[] oldPi = pi.clone();
updateLambda(x, gamma, ksi);
//double maxError = calcError(oldA, oldPi, oldB);
double maxError = calcError(oldA, null, null);
logger.info("max_error:"+maxError);
if(maxError < precision || (Math.abs(maxError-oldMaxError)) < this.precision) {
logger.info("参数已收敛....");
break;
}
oldMaxError = maxError;
iter++;
long end = System.currentTimeMillis();
logger.info("本次迭代结束,耗时:"+(end - start)+"毫秒");
}
logger.info("最终参数:");
logger.info("pi:"+Arrays.toString(pi));
logger.info("A:");
for(int i = 0; i < transferProbability1.length; i++) {
logger.info(Arrays.toString(transferProbability1[i]));
}
}
/**
* 保存旧的参数A
* @return
*/
protected double[][] generateOldA() {
double[][] oldA = new double[stateNum][stateNum];
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
oldA[i][j] = transferProbability1[i][j];
}
}
return oldA;
}
/**
* 保存旧的参数B
* @return
*/
protected double[][] generateOldB() {
double[][] oldB = new double[stateNum][observationNum];
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < observationNum; j++) {
oldB[i][j] = emissionProbability[i][j];
}
}
return oldB;
}
/**
* 暂时只计算参数A的误差
* 发现计算B和pi会发现参数误差越来越大的现象,基本不能收敛
* @param old
* @return
*/
protected double calcError(double[][] oldA, double[] oldPi, double[][] oldB) {
double maxError = 0;
for(int i =0 ; i < stateNum; i++) {
/*double tmp1 = Math.abs(pi[i] - oldPi[i]);
maxError = tmp1 > maxError ? tmp1 : maxError;*/
for(int j =0; j < stateNum; j++) {
double tmp = Math.abs(oldA[i][j] - transferProbability1[i][j]);
maxError = tmp > maxError ? tmp : maxError;
}
/*for(int k =0; k < observationNum; k++) {
double tmp2 = Math.abs(emissionProbability[i][k] - oldB[i][k]);
maxError = tmp2 > maxError ? tmp2 : maxError;
}*/
}
return maxError;
}
/**
* 概率初始化为0
*/
public void initParameters() {
//初始概率随机初始化
pi = new double[stateNum];
transferProbability1 = new double[stateNum][stateNum];
emissionProbability = new double[stateNum][observationNum];
//概率初始化为0
for(int i = 0; i < stateNum; i++) {
pi[i] = INFINITY;
for(int j = 0; j < stateNum; j++) {
transferProbability1[i][j] = INFINITY;
}
for(int k = 0; k < observationNum; k++) {
emissionProbability[i][k] = INFINITY;
}
}
}
/**
* 数组求和
* @param arr
* @return
*/
public static double sum(double[] arr) {
double sum = 0;
for(int i = 0; i < arr.length;i++) {
sum += arr[i];
}
return sum;
}
/**
* 随机初始化参数PI
*/
public void randomInitPi() {
for(int i = 0; i < stateNum; i++) {
pi[i] = Math.random() * 100;
}
//log归一化
double sum = Math.log(sum(pi));
for(int i =0; i < stateNum; i++) {
if(pi[i] == 0) {
pi[i] = INFINITY;
continue;
}
pi[i] = Math.log(pi[i]) - sum;
}
}
/**
* 随机初始化参数A
*/
public void randomInitA() {
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
transferProbability1[i][j] = Math.random()*100;;
}
double sum = Math.log(sum(transferProbability1[i]));
for(int k = 0; k < stateNum; k++) {
if(transferProbability1[i][k] == 0) {
transferProbability1[i][k] = INFINITY;
continue;
}
transferProbability1[i][k] = Math.log(transferProbability1[i][k]) - sum;
}
}
}
/**
* 随机初始化参数B
*/
public void randomInitB() {
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < observationNum; j++) {
emissionProbability[i][j] = Math.random()*100;;
}
double sum = Math.log(sum(emissionProbability[i]));
for(int k = 0; k < observationNum; k++) {
if(emissionProbability[i][k] == 0) {
emissionProbability[i][k] = INFINITY;
continue;
}
emissionProbability[i][k] = Math.log(emissionProbability[i][k]) - sum;
}
}
}
/**
* 随机初始化所有参数
*/
public void randomInitAllParameters() {
randomInitA();
randomInitB();
randomInitPi();
}
/**
* 前向算法,根据当前参数λ计算α
* α是一个序列长度*状态长度的矩阵
* 已检测,应该没有问题
*/
protected void calcAlpha(int[] x, double[][] alpha) {
logger.info("计算alpha...");
long start = System.currentTimeMillis();
//double[][] alpha = new double[sequenceLen][stateNum];
//alpha t=0初始值
for(int i = 0; i < stateNum; i++) {
alpha[0][i] = pi[i] + emissionProbability[i][x[0]];
}
double[] logProbaArr = new double[stateNum];
for(int t = 1; t < sequenceLen; t++) {
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
logProbaArr[j] = (alpha[t -1][j] + transferProbability1[j][i]);
}
alpha[t][i] = logSum(logProbaArr) + emissionProbability[i][x[t]];
}
}
long end = System.currentTimeMillis();
logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
//return alpha;
}
/**
* 后向算法,根据当前参数λ计算β
*
* @param x
*/
protected void calcBeta(int[] x, double[][] beta) {
logger.info("计算beta...");
long start = System.currentTimeMillis();
//double[][] beta = new double[sequenceLen][stateNum];
//初始概率beta[T][i] = 1
for(int i = 0; i < stateNum; i++) {
beta[sequenceLen-1][i] = 1;
}
double[] logProbaArr = new double[stateNum];
for(int t = sequenceLen -2; t >= 0; t--) {
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
logProbaArr[j] = transferProbability1[i][j] +
emissionProbability[j][x[t+1]] +
beta[t + 1][j];
}
beta[t][i] = logSum(logProbaArr);
}
}
long end = System.currentTimeMillis();
logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
//return beta;
}
/**
* 根据当前参数λ计算ξ
* @param x 观测结点
* @param alpha 前向概率
* @param beta 后向概率
*/
protected void calcKsi(int[] x, double[][] alpha, double[][] beta, double[][][] ksi) {
logger.info("计算ksi...");
long start = System.currentTimeMillis();
//double[][][] ksi = new double[sequenceLen][stateNum][stateNum];
double[] logProbaArr = new double[stateNum * stateNum];
for(int t = 0; t < sequenceLen -1; t++) {
int k = 0;
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
ksi[t][i][j] = alpha[t][i] + transferProbability1[i][j] +
emissionProbability[j][x[t+1]]+beta[t+1][j];
logProbaArr[k++] = ksi[t][i][j];
}
}
double logSum = logSum(logProbaArr);//分母
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
ksi[t][i][j] -= logSum;//分子除分母
}
}
}
long end = System.currentTimeMillis();
logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
//return ksi;
}
/**
* 根据当前参数λ,计算γ
* @param x
*/
protected void calcGamma(int[] x, double[][] alpha, double[][] beta, double[][] gamma) {
logger.info("计算gamma...");
long start = System.currentTimeMillis();
//double[][] gamma = new double[sequenceLen][stateNum];
for(int t = 0; t < sequenceLen; t++) {
//分母需要求LogSum
for(int i = 0; i < stateNum; i++) {
gamma[t][i] = alpha[t][i] + beta[t][i];
}
double logSum = logSum(gamma[t]);//分母部分
for(int j = 0; j < stateNum; j++) {
gamma[t][j] = gamma[t][j] - logSum;
}
}
long end = System.currentTimeMillis();
logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
//return gamma;
}
/**
* 更新参数
*/
protected void updateLambda(int[] x ,double[][] gamma, double[][][] ksi) {
//顺序可以颠倒
updatePi(gamma);
updateA(ksi, gamma);
updateB(x, gamma);
}
/**
* 更新参数pi
* @param gamma
*/
public void updatePi(double[][] gamma) {
//更新HMM中的参数pi
for(int i = 0; i < stateNum; i++) {
pi[i] = gamma[0][i];
}
}
/**
* 更新参数A
* @param ksi
* @param gamma
*/
protected void updateA(double[][][] ksi, double[][] gamma) {
logger.info("更新参数转移概率A...");
由于在更新A都要用到对不同状态的前T-1的gamma值求和,所以这里先算
double[] gammaSum = new double[stateNum];
double[] tmp = new double[sequenceLen -1];
for(int i = 0; i < stateNum; i++) {
for(int t = 0; t < sequenceLen -1; t++) {
tmp[t] = gamma[t][i];
}
gammaSum[i] = logSum(tmp);
}
long start1 = System.currentTimeMillis();
//更新HMM中的参数A
double[] ksiLogProbArr = new double[sequenceLen - 1];
for(int i = 0; i < stateNum; i++) {
for(int j = 0; j < stateNum; j++) {
for(int t = 0; t < sequenceLen -1; t++) {
ksiLogProbArr[t] = ksi[t][i][j];
}
transferProbability1[i][j] = logSum(ksiLogProbArr) - gammaSum[i];
}
}
long end1 = System.currentTimeMillis();
logger.info("更新完毕...耗时:"+(end1 - start1)+"毫秒");
}
/**
* 更新参数B
* @param x
* @param gamma
*/
protected void updateB(int[] x, double[][] gamma) {
//下面需要用到gamma求和为了减少重复计算,这里直接先计算
//由于在更新B时都要用到对不同状态的所有gamma值求和,所以这里先算
double[] gammaSum2 = new double[stateNum];
double[] tmp2 = new double[sequenceLen];
for(int i = 0; i < stateNum; i++) {
for(int t = 0; t < sequenceLen; t++) {
tmp2[t] = gamma[t][i];
}
gammaSum2[i] = logSum(tmp2);
}
logger.info("更新状态下分布概率B...");
long start2 = System.currentTimeMillis();
ArrayList<Double> valid = new ArrayList<Double>();
for(int i = 0; i < stateNum; i++) {
for(int k = 0; k < observationNum; k++) {
valid.clear();//由于这里没有初始化造成了计算出错的问题
for(int t = 0; t < sequenceLen; t++) {
if(x[t] == k) {
valid.add(gamma[t][i]);
}
}
//B[i][k],i状态下k的分布为概率0,
if(valid.size() == 0) {
emissionProbability[i][k] = INFINITY;
continue;
}
//对分子求logSum
double[] validArr = new double[valid.size()];
for(int q = 0; q < valid.size(); q++) {
validArr[q] = valid.get(q);
}
double validSum = logSum(validArr);
//分母的logSum已经在上面做了
emissionProbability[i][k] = validSum - gammaSum2[i];
}
}
long end2 = System.currentTimeMillis();
logger.info("更新完毕...耗时:"+(end2 - start2)+"毫秒");
}
/**
* logSum计算技巧
* @param tmp
* @return
*/
public double logSum(double[] logProbaArr) {
if(logProbaArr.length == 0) {
return INFINITY;
}
double max = max(logProbaArr);
double result = 0;
for(int i = 0; i < logProbaArr.length; i++) {
result += Math.exp(logProbaArr[i] - max);
}
return max + Math.log(result);
}
/**
* 设置先验概率pi
* 必须传入取对数后的概率
* @param pi
*/
public void setPriorPi(double[] pi){
this.pi = pi;
}
/**
* 设置先验转移概率A
* 必须传入取对数的概率
* @param trtransferProbability1
*/
public void setPriorTransferProbability1(double[][] trtransferProbability1){
this.transferProbability1 = trtransferProbability1;
}
/**
* 设置先验状态下的观测分布概率,B
* 必须传入取对数的概率
* @param emissionProbability
*/
public void setPriorEmissionProbability(double[][] emissionProbability) {
this.emissionProbability = emissionProbability;
}
public static double max(double[] arr) {
double max = arr[0];
for(int i = 1; i < arr.length;i++) {
max = arr[i] > max ? arr[i] : max;
}
return max;
}
/**
* 维特比解码
* @param O 观测序列,输入的是经过编码处理的,而不是原始数据,
* 比如,如果序列是字符串,那么输入必须是一系列的字符的编码而不是字符本身
* @return 返回预测结果,
*/
public int[] verterbi(int[] O) {
double[][] deltas = new double[O.length][this.stateNum];
//保存deltas[t][i]的值是由上一个哪个状态产生的
int[][] states = new int[O.length][this.stateNum];
//初始化deltas[0][]
for(int i = 0;i < this.stateNum; i++) {
deltas[0][i] = pi[i] + emissionProbability[i][O[0]];
}
//计算deltas
for(int t = 1; t < O.length; t++) {
for(int i = 0; i < this.stateNum; i++) {
deltas[t][i] = deltas[t-1][0]+transferProbability1[0][i];
for(int j = 1; j < this.stateNum; j++) {
double tmp = deltas[t-1][j]+transferProbability1[j][i];
if (tmp > deltas[t][i]) {
deltas[t][i] = tmp;
states[t][i] = j;
}
}
deltas[t][i] += emissionProbability[i][O[t]];
}
}
//回溯找到最优路径
int[] predict = new int[O.length];
double max = deltas[O.length-1][0];
for(int i = 1; i < this.stateNum; i++) {
if(deltas[O.length-1][i] > max) {
max = deltas[O.length-1][i];
predict[O.length-1] = i;
}
}
for(int i = O.length-2;i >= 0;i-- ) {
predict[i] = states[i+1][predict[i+1]];
}
return predict;
}
//测试
public static void main(String[] args) {
UnsupervisedFirstOrderGeneralHMM hmm = new UnsupervisedFirstOrderGeneralHMM(4, 65536);
//关闭日志打印
//CONLPLogger.closeLogger(hmm.logger);
//由于是监督学习的语料所以这里需要去掉其中的分隔符
String path = "src/pku_training.splitBy2space.utf8";
String data = IOUtils.readText(path, "utf-8");
String[] d2 = data.split(" ");
StringBuilder sb = new StringBuilder();
for(String word : d2) {
sb.append(word);
}
data = sb.toString();
//训练数据
int[] x = SegmentationUtils.str2int(data);
//由于串行很慢,可以只取训练数据的前10000个来训练
int[] minX = new int[10000];
System.arraycopy(x, 0, minX, 0, 10000);
//训练之前设置先验概率,必须设置,EM对初始值敏感,如果不设置默认为都为0,所有参数都将一样,没有意义
//如果只给了其中一些参数的先验值,可以随机初始化其他参数,例如
//hmm.randomInitA();
//hmm.randomInitB();
//hmm.randomInitPi();
//hmm.randomInitAllParameters();
//设置先验信息至少设置参数pi,A,B中的一个
hmm.setPriorPi(new double[] {-1.138130826175848, -2.632826946498266, -1.138130826175848, -1.2472622308278396});
hmm.setPriorTransferProbability1((double[][]) IOUtils.readObject("src/A"));
hmm.setPriorEmissionProbability((double[][]) IOUtils.readObject("src/B"));
//开始训练
hmm.train(minX, -1, 0.5);
String str = "原标题:日媒拍到了现场罕见一幕" +
"据日本新闻网(NNN)9月8日报道,日前,日本海上自卫队现役最大战舰之一的直升机航母“加贺”号在南海航行时,遭多艘中国海军战舰抵近跟踪监视。" ;
//将词转换为对应的Unicode码
int[] O = SegmentationUtils.str2int(str);
int[] predict = hmm.verterbi(O);
System.out.println(Arrays.toString(predict));
String[] res = SegmentationUtils.decode(predict, str);
System.out.println(Arrays.toString(res));
}
}
依赖IoUtils:
package com.outsider.test;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
public class IOUtils {
public static String readTextWithLineCheckBreak(String path, String encoding) {
return readText(path, encoding, "\n");
}
/**
* 读取文本文件,返回整个字符串,不包括换行符号
* @param path 文件路径
* @param encoding 编码,传入null或者空串使用默认编码
* @return
*/
public static String readText(String path, String encoding) {
return readText(path, encoding, null);
}
/**
* 读取文本,指定每一行末尾符号
* @param path
* @param encoding
* @param lineEndStr
* @return
*/
public static String readText(String path, String encoding, String lineEndStr) {
try {
if(lineEndStr == null) {
lineEndStr = "";
}
BufferedReader reader = null;
if((!encoding.trim().equals(""))&&encoding!=null) {
reader = new BufferedReader(new InputStreamReader(new FileInputStream(path),encoding));
} else {
reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
}
String s="";
StringBuilder sb = new StringBuilder();
while((s=reader.readLine())!=null) {
sb.append(s+lineEndStr);
}
reader.close();
return sb.toString();
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
/**
* 读取文本文件,返回整个字符串,不包括换行符号
* @param path 文件路径
* @param encoding 编码,传入null或者空串使用默认编码
* @param addNewLine 是否加换行符
* @return
*/
public static List<String> readTextAndReturnLinesCheckLineBreak(String path, String encoding, boolean addNewLine) {
try {
String lineBreak;
if(addNewLine) {
lineBreak = "\n";
} else {
lineBreak = "";
}
BufferedReader reader = null;
if((!encoding.trim().equals(""))&&encoding!=null) {
reader = new BufferedReader(new InputStreamReader(new FileInputStream(path),encoding));
} else {
reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
}
String s="";
List<String> list = new ArrayList<>();
while((s=reader.readLine())!=null) {
list.add(s+lineBreak);
}
reader.close();
return list;
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public static List<String> readTextAndReturnLines(String path, String encoding){
return readTextAndReturnLinesCheckLineBreak(path, encoding, false);
}
/**
* 读取文本的每一行
* 并且返回数组形式
* @param path
* @param encoding
* @return
*/
public static String[] readTextAndReturnLinesOfArray(String path, String encoding){
List<String> lines = readTextAndReturnLines(path, encoding);
String[] arr = new String[lines.size()];
lines.toArray(arr);
return arr;
}
/**
* 写入文本文件
* @param data
* @param path
* @param encoding
*/
public static void writeTextData2File(String data,String path,String encoding) {
try {
BufferedWriter writer = null;
if((!encoding.trim().equals(""))&&encoding!=null) {
writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path),encoding));
} else {
writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path)));
}
writer.write(data);
writer.close();
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 把对象写入文件
* @param path
* @param object
*/
public static void writeObject2File(String path, Object object) {
try {
ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(path));
out.writeObject(object);
out.close();
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 读取对象
* @param path
* @return
*/
public static Object readObject(String path) {
try {
ObjectInputStream in = new ObjectInputStream(new FileInputStream(path));
return in.readObject();
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
}
依赖的SegmentationUtils:
package com.outsider.test;
import java.util.ArrayList;
import java.util.List;
public class SegmentationUtils {
/**
* 将字符串数组的每一个字符串中的字符直接转换为Unicode码
* @param strs 字符串数组
* @return Unicode值
*/
public static List<int[]> strs2int(String[] strs) {
List<int[]> res = new ArrayList<>(strs.length);
for(int i = 0; i < strs.length;i++) {
int[] O = new int[strs[i].length()];
for(int j = 0; j < strs[i].length();j++) {
O[j] = strs[i].charAt(j);
}
res.add(O);
}
return res;
}
public static int[] str2int(String str) {
return strs2int(new String[] {str}).get(0);
}
/**
* 根据预测结果解码
* BEMS 0123
* @param predict 预测结果
* @param sentence 句子
* @return
*/
public static String[] decode(int[] predict, String sentence) {
List<String> res = new ArrayList<>();
char[] chars = sentence.toCharArray();
for(int i = 0; i < predict.length;i++) {
if(predict[i] == 0 || predict[i] == 1) {
int a = i;
while(predict[i] != 2) {
i++;
if(i == predict.length) {
break;
}
}
int b = i;
if(b == predict.length) {
b--;
}
res.add(new String(chars,a,b-a+1));
} else {
res.add(new String(chars,i,1));
}
}
String[] s = new String[res.size()];
return res.toArray(s);
}
}