em 算法 java_谁做过 EM算法 java实现

展开全部

参考:package nlp;

/**

* @author Orisun

* date 2011-10-22

*/

import java.util.ArrayList;

public class BaumWelch {

int M; // 隐藏状态的种数

int N; // 输出活动的种数

double[] PI; // 初始636f707962616964757a686964616f31333361323033状态概率矩阵

double[][] A; // 状态转移矩阵

double[][] B; // 混淆矩阵

ArrayList observation = new ArrayList(); // 观察到的集合

ArrayList state = new ArrayList(); // 中间状态集合

int[] out_seq = { 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1,

1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1 }; // 测试用的观察序列

int[] hidden_seq = { 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,

1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1 }; // 测试用的隐藏状态序列

int T = 32; // 序列长度为32

double[][] alpha = new double[T][]; // 向前变量

double PO;

double[][] beta = new double[T][]; // 向后变量

double[][] gamma = new double[T][];

double[][][] xi = new double[T - 1][][];

// 初始化参数。Baum-Welch得到的是局部最优解,所以初始参数直接影响解的好坏

public void initParameters() {

M = 2;

N = 2;

PI = new double[M];

PI[0] = 0.5;

PI[1] = 0.5;

A = new double[M][];

B = new double[M][];

for (int i = 0; i 

A[i] = new double[M];

B[i] = new double[N];

}

A[0][0] = 0.8125;

A[0][1] = 0.1875;

A[1][0] = 0.2;

A[1][1] = 0.8;

B[0][0] = 0.875;

B[0][1] = 0.125;

B[1][0] = 0.25;

B[1][1] = 0.75;

observation.add(1);

observation.add(2);

state.add(1);

state.add(2);

for (int t = 0; t 

alpha[t] = new double[M];

beta[t] = new double[M];

gamma[t] = new double[M];

}

for (int t = 0; t 

xi[t] = new double[M][];

for (int i = 0; i 

xi[t][i] = new double[M];

}

}

// 更新向前变量

public void updateAlpha() {

for (int i = 0; i 

alpha[0][i] = PI[i] * B[i][observation.indexOf(out_seq[0])];

}

for (int t = 1; t 

for (int i = 0; i 

alpha[t][i] = 0;

for (int j = 0; j 

alpha[t][i] += alpha[t - 1][j] * A[j][i];

}

alpha[t][i] *= B[i][observation.indexOf(out_seq[t])];

}

}

}

// 更新观察序列出现的概率,它在一些公式中当分母

public void updatePO() {

for (int i = 0; i 

PO += alpha[T - 1][i];

}

// 更新向后变量

public void updateBeta() {

for (int i = 0; i 

beta[T - 1][i] = 1;

}

for (int t = T - 2; t >= 0; t--) {

for (int i = 0; i 

for (int j = 0; j 

beta[t][i] += A[i][j]

* B[j][observation.indexOf(out_seq[t + 1])]

* beta[t + 1][j];

}

}

}

}

// 更新xi

public void updateXi() {

for (int t = 0; t 

double frac = 0.0;

for (int i = 0; i 

for (int j = 0; j 

frac += alpha[t][i] * A[i][j]

* B[j][observation.indexOf(out_seq[t + 1])]

* beta[t + 1][j];

}

}

for (int i = 0; i 

for (int j = 0; j 

xi[t][i][j] = alpha[t][i] * A[i][j]

* B[j][observation.indexOf(out_seq[t + 1])]

* beta[t + 1][j] / frac;

}

}

}

}

// 更新gamma

public void updateGamma() {

for (int t = 0; t 

double frac = 0.0;

for (int i = 0; i 

frac += alpha[t][i] * beta[t][i];

}

// double frac = PO;

for (int i = 0; i 

gamma[t][i] = alpha[t][i] * beta[t][i] / frac;

}

// for(int i=0;i

// gamma[t][i]=0;

// for(int j=0;j

// gamma[t][i]+=xi[t][i][j];

// }

}

}

// 更新状态概率矩阵

public void updatePI() {

for (int i = 0; i 

PI[i] = gamma[0][i];

}

// 更新状态转移矩阵

public void updateA() {

for (int i = 0; i 

double frac = 0.0;

for (int t = 0; t 

frac += gamma[t][i];

}

for (int j = 0; j 

double dem = 0.0;

// for (int t = 0; t 

// dem += xi[t][i][j];

// for (int k = 0; k 

// frac += xi[t][i][k];

// }

for (int t = 0; t 

dem += xi[t][i][j];

}

A[i][j] = dem / frac;

}

}

}

// 更新混淆矩阵

public void updateB() {

for (int i = 0; i 

double frac = 0.0;

for (int t = 0; t 

frac += gamma[t][i];

for (int j = 0; j 

double dem = 0.0;

for (int t = 0; t 

if (out_seq[t] == observation.get(j))

dem += gamma[t][i];

}

B[i][j] = dem / frac;

}

}

}

// 运行Baum-Welch算法

public void run() {

initParameters();

int iter = 22; // 迭代次数

while (iter-- > 0) {

// E-Step

updateAlpha();

// updatePO();

updateBeta();

updateGamma();

updatePI();

updateXi();

// M-Step

updateA();

updateB();

}

}

public static void main(String[] args) {

BaumWelch bw = new BaumWelch();

bw.run();

System.out.println("训练后的初始状态概率矩阵:");

for (int i = 0; i 

System.out.print(bw.PI[i] + "\t");

System.out.println();

System.out.println("训练后的状态转移矩阵:");

for (int i = 0; i 

for (int j = 0; j 

System.out.print(bw.A[i][j] + "\t");

}

System.out.println();

}

System.out.println("训练后的混淆矩阵:");

for (int i = 0; i 

for (int j = 0; j 

System.out.print(bw.B[i][j] + "\t");

}

System.out.println();

}

}

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值