前面已经介绍了隐马尔可夫模型,本篇博文主要是介绍用 viterbi 算法来解决 HMM 中的预测问题,也称为解码问题。
维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划(dynamic programming)求概率最大路径(最优路径)。这时一条路径对应着一个状态序列。
根据动态规划原理,最优路径具有这样的特性:如果最优路径在时刻t通过
(it)∗
,那么这一路径从
it∗
到终点
iT∗
的部分路径,对于从
it∗
到
iT∗
的所有可能的部分路径来说,必须是最优的。因为假如不是这样,那么从
i1∗
到终点
iT∗
就有另一条更好的部分路径存在,如果把它和
i1∗
到终点
it∗
的部分路径连接起来,就会形成一条比原来的路径更优的路径,这是矛盾的。依据这一原理,我们只需从时刻t=1开始,递推地计算在时刻t状态为i的各条部分路径的最大概率,直至得到时刻
t=T
状态为i的各条路径的最大概率。时刻
t=T
的最大概率即为最优路径的概率
P∗
,最优路径的终结点
iT∗
也同时得到。之后,为了找出最优路径的各个结点,从终结点
iT∗
开始,由后向前逐步求得结点
iT−1∗,...,i1∗
得到最优路径这就是维特比算法。
viterbi 算法
输入:模型 λ=(A,B,π) 和观测 O=(o1,o2,...,oT) ;
输出:最优路径 (i1∗,...,iT−1∗,iT∗) .
(1) 初始化
δ1(i)=πibi(oi),i=1,2,...,N
ψ1(i)=0,i=1,2,...,N
(2) 递推.对 t=2,3,...,T
δt(i)=max[δt−1(j)aji]bi(ot),i=1,2,..,N;1≤j≤N
ψt(i)=argmax[δt−1(j)aji],i=1,2,...,N;1≤j≤N
(3) 终止
P∗=maxδT(i),1≤j≤N
iT∗=argmax[δT(i)],1≤j≤N
(4)最优路径回溯. 对 t=T−1,T−2,...,1
it∗=ψt+1(i∗t+1)viterbi算法实现
package com.feng.nlp.algorithm;
import java.util.*;
/**
* Created by lionel on 17/4/11.
*/
public class Viterbi {
public static List<String> compute(String[] observe, String[] status, double[] start_p, double[][] transfer_p, double[][] observe_p) {
double[][] theta = new double[observe.length][status.length];
int[][] delta = new int[observe.length][status.length];
transfermation(start_p, transfer_p, observe_p);
for (int j = 0; j < status.length; j++) {
theta[0][j] = start_p[j] + observe_p[j][0];
delta[0][j] = 0;
}
Map<String, Integer> map = new HashMap<String, Integer>();
int index = 0;
for (String ele : observe) {
if (map.containsKey(ele)) {
continue;
}
map.put(ele, index);
index++;
}
for (int i = 1; i < observe.length; i++) {
for (int j = 0; j < status.length; j++) {
int direction = 0;
double prob = Double.MAX_VALUE;
for (int k = 0; k < status.length; k++) {
double tmpProb = theta[i - 1][k] + transfer_p[k][j] + observe_p[j][map.get(observe[i])];
if (tmpProb < prob) {
prob = tmpProb;
direction = k;
theta[i][j] = prob;
}
}
delta[i][j] = direction;
}
}
// for (int i = 0; i < theta.length; i++) {
// for (int j = 0; j < theta[i].length; j++) {
// System.out.print(theta[i][j] + " ");
// }
// System.out.println();
// }
double prob = Double.MAX_VALUE;
int pos = 0;
for (int j = 0; j < status.length; j++) {
if (theta[observe.length - 1][j] < prob) {
prob = theta[observe.length - 1][j];
pos = j;
}
}
List<String> res = new ArrayList<String>();
res.add(status[pos]);
//回溯路径
for (int i = observe.length - 1; i > 0; i--) {
res.add(status[delta[i][pos]]);
pos = delta[i][pos];
}
Collections.reverse(res);
return res;
}
public static void transfermation(double[] start_p, double[][] transfer_p, double[][] observe_p) {
for (int i = 0; i < start_p.length; ++i) {
start_p[i] = -Math.log(start_p[i]);
}
for (int i = 0; i < transfer_p.length; ++i) {
for (int j = 0; j < transfer_p[i].length; ++j) {
transfer_p[i][j] = -Math.log(transfer_p[i][j]);
}
}
for (int i = 0; i < observe_p.length; ++i) {
for (int j = 0; j < observe_p[i].length; ++j) {
observe_p[i][j] = -Math.log(observe_p[i][j]);
}
}
}
public static void main(String[] args) {
String[] observe = {"红", "白", "红"};
String[] status = {"1", "2", "3"};
double[] start_p = new double[]{0.2, 0.4, 0.4};
double[][] transfer_p = new double[][]{
{0.5, 0.2, 0.3},
{0.3, 0.5, 0.2},
{0.2, 0.3, 0.5}
};
double[][] observe_p = new double[][]{
{0.5, 0.5},
{0.4, 0.6},
{0.7, 0.3}
};
List<String> result = compute(observe, status, start_p, transfer_p, observe_p);
System.out.println(result);//[3, 3, 3]
}
}
测试用例来源于李航老师的《统计机器学习》的例子。
- 参考资料:《统计机器学习》,李航