package RareProb_v2;
/**
 * Viterbi decoder for a discrete hidden Markov model: given the initial distribution
 * {@code pi}, the transition matrix {@code A}, the emission matrix {@code B} and an
 * observation sequence, it finds the most probable sequence of hidden states.
 *
 * <p>Usage: construct, call {@link #partialPr()} to fill the dynamic-programming tables,
 * then call {@link #getSequence()} (prints the path) or {@link #bestPath()} (returns it).
 *
 * <p>Hidden states are identified by zero-based indices {@code 0 .. hiddenStateNum-1}.
 * Observation symbols are one-based: symbol {@code k} selects column {@code k-1} of {@code B}.
 *
 * <p>All decisions (back-pointers and the final state) are taken on log-probabilities, so
 * long observation sequences do not degrade into an arbitrary path once the plain
 * probability products underflow to zero. The {@link #deta} table still holds the plain
 * probabilities.
 */
public class VertibeAlgorithm {
    /** Number of hidden states. */
    int hiddenStateNum;
    /** Number of observations that are decoded (a prefix of {@link #ObserveSeq}). */
    int T;
    /** Initial state probabilities; {@code pi[i]} is the probability of starting in state i. */
    double[] pi;
    /** Transition probabilities; {@code A[j][i]} is the probability of moving from state j to state i. */
    double[][] A;
    /** Emission probabilities; {@code B[i][k-1]} is the probability that state i emits symbol k. */
    double[][] B;
    /** Observed symbols, one-based. */
    int[] ObserveSeq;
    /**
     * {@code deta[t][i]} is the largest probability of any state path that ends in state i
     * at time t and emits the first t+1 observations. May underflow to 0 for long sequences.
     */
    double[][] deta;
    /**
     * Back-pointers: {@code psi[t][i]} is the predecessor state on the best path that ends
     * in state i at time t. Row 0 is never written and is not read when backtracking.
     */
    int[][] psi;
    /** Log-domain counterpart of {@link #deta}; this is what the decoding decisions are based on. */
    private double[][] logDeta;

    /**
     * Creates a decoder and allocates its tables.
     *
     * @param hiddenStateNum number of hidden states, at least 1
     * @param ObserveSeqNum  number of leading entries of {@code ObserveSeq} to decode, at least 0
     * @param ObserveSeq     observed symbols (one-based), holding at least {@code ObserveSeqNum} entries
     * @param pi             initial state probabilities
     * @param A              transition probabilities, indexed {@code [from][to]}
     * @param B              emission probabilities, indexed {@code [state][symbol-1]}
     * @throws IllegalArgumentException if a count is out of range or {@code ObserveSeq} is
     *                                  shorter than {@code ObserveSeqNum}
     */
    VertibeAlgorithm(int hiddenStateNum, int ObserveSeqNum, int[] ObserveSeq,
                     double[] pi, double[][] A, double[][] B) {
        if (hiddenStateNum < 1) {
            throw new IllegalArgumentException(
                    "hiddenStateNum must be at least 1, got " + hiddenStateNum);
        }
        if (ObserveSeqNum < 0) {
            throw new IllegalArgumentException(
                    "ObserveSeqNum must not be negative, got " + ObserveSeqNum);
        }
        int available = (ObserveSeq == null) ? 0 : ObserveSeq.length;
        if (ObserveSeqNum > available) {
            throw new IllegalArgumentException(
                    "ObserveSeqNum is " + ObserveSeqNum + " but only " + available
                            + " observations were supplied");
        }
        this.hiddenStateNum = hiddenStateNum;
        this.T = ObserveSeqNum;
        this.ObserveSeq = ObserveSeq;
        this.pi = pi;
        this.A = A;
        this.B = B;
        deta = new double[T][hiddenStateNum];
        psi = new int[T][hiddenStateNum];
        logDeta = new double[T][hiddenStateNum];
    }

    /**
     * Runs the forward pass: fills {@link #deta} and {@link #psi} for every time step and
     * prints each back-pointer as it is determined. Does nothing when there are no observations.
     */
    public void partialPr() {
        if (T == 0) {
            return;
        }

        // t = 0: initial distribution times the emission of the first symbol.
        int firstSymbol = ObserveSeq[0] - 1;
        for (int i = 0; i < hiddenStateNum; i++) {
            deta[0][i] = pi[i] * B[i][firstSymbol];
            logDeta[0][i] = Math.log(pi[i]) + Math.log(B[i][firstSymbol]);
        }

        // t >= 1: extend the best path into each state by one step.
        for (int t = 1; t < T; t++) {
            int symbol = ObserveSeq[t] - 1;
            for (int i = 0; i < hiddenStateNum; i++) {
                double bestProb = 0;
                double bestLog = Double.NEGATIVE_INFINITY;
                int bestPrev = 0; // stays 0 when no predecessor has non-zero probability
                for (int j = 0; j < hiddenStateNum; j++) {
                    double prob = deta[t - 1][j] * A[j][i] * B[i][symbol];
                    if (prob > bestProb) {
                        bestProb = prob;
                    }
                    // The emission term is the same for every j, so the predecessor is
                    // chosen on the transition score alone; the first maximum wins ties.
                    double score = logDeta[t - 1][j] + Math.log(A[j][i]);
                    if (score > bestLog) {
                        bestLog = score;
                        bestPrev = j;
                    }
                }
                deta[t][i] = bestProb;
                logDeta[t][i] = bestLog + Math.log(B[i][symbol]);
                psi[t][i] = bestPrev;
                System.out.println("psi[" + t + "][" + i + "]=" + psi[t][i]);
            }
        }
    }

    /**
     * Returns the most probable hidden-state path as zero-based state indices, one per
     * observation. {@link #partialPr()} must have been called first.
     *
     * <p>The final state is the one with the highest score at the last time step (the lowest
     * index wins ties, so state 0 is reported when every path has probability zero); earlier
     * states are recovered by following the back-pointers.
     *
     * @return a new array of length {@code T}; empty when there are no observations
     */
    public int[] bestPath() {
        int[] path = new int[T];
        if (T == 0) {
            return path;
        }
        int last = 0;
        for (int i = 1; i < hiddenStateNum; i++) {
            if (logDeta[T - 1][i] > logDeta[T - 1][last]) {
                last = i;
            }
        }
        path[T - 1] = last;
        for (int t = T - 2; t >= 0; t--) {
            path[t] = psi[t + 1][path[t + 1]];
        }
        return path;
    }

    /**
     * Prints the most probable hidden-state path: a header line followed by one zero-based
     * state index per line. {@link #partialPr()} must have been called first.
     */
    public void getSequence() {
        int[] path = bestPath();
        System.out.println("sequece:");
        for (int t = 0; t < path.length; t++) {
            System.out.println(path[t]);
        }
    }

    /**
     * Small demonstration: decodes a ten-symbol sequence with a three-state model.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Three hidden states; the decoded path reports them as zero-based indices 0..2.
        // Observation symbols are one-based (1, 2) and are unrelated to the state numbering.
        int hiddenStateNum = 3;
        int ObserveSeqNum = 10;
        int[] ObserveSeq = {1, 1, 1, 1, 2, 1, 2, 2, 2, 2};
        double[] pi = {0.333, 0.333, 0.333};
        double[][] A = {{0.333, 0.333, 0.333}, {0.333, 0.333, 0.333}, {0.333, 0.333, 0.333}};
        double[][] B = {{0.5, 0.5}, {0.75, 0.25}, {0.25, 0.75}};
        VertibeAlgorithm va = new VertibeAlgorithm(hiddenStateNum, ObserveSeqNum, ObserveSeq, pi, A, B);
        va.partialPr();
        va.getSequence();
    }
}
// Please credit the original source when reposting.