import java.util.*;
/**
 * Viterbi decoding for a discrete hidden Markov model, demonstrated on the classic
 * "Healthy / Fever" example: given the symptom observed on each day, find the most likely
 * sequence of hidden health states.
 *
 * <p>References:
 * https://zh.wikipedia.org/zh-hans/%E7%BB%B4%E7%89%B9%E6%AF%94%E7%AE%97%E6%B3%95
 * https://vimsky.com/article/863.html
 *
 * @author gdretop
 * @date 2019/7/12
 */
public class Viterbi {
    private static final String HEALTHY = "Healthy";
    private static final String FEVER = "Fever";
    private static final String NORMAL = "normal";
    private static final String COLD = "cold";
    private static final String DIZZY = "dizzy";

    /** Hidden states of the example model. */
    static Set<String> states = new HashSet<>(Arrays.asList(HEALTHY, FEVER));
    /** Symptom observed on each day, in chronological order. */
    static String[] observations = {NORMAL, COLD, DIZZY};
    /** Probability of each hidden state on day 0. */
    static Map<String, Double> startProbability = new HashMap<>();
    /** {@code transitionProbility.get(from).get(to)}: probability of moving from one day's state to the next. */
    static Map<String, Map<String, Double>> transitionProbility = new HashMap<>();
    /** {@code emissionProbalility.get(state).get(observation)}: probability of a symptom given a state. */
    static Map<String, Map<String, Double>> emissionProbalility = new HashMap<>();

    // Plain static initialization instead of "double brace" anonymous HashMap subclasses.
    static {
        startProbability.put(HEALTHY, 0.6);
        startProbability.put(FEVER, 0.4);

        Map<String, Double> fromHealthy = new HashMap<>();
        fromHealthy.put(HEALTHY, 0.7);
        fromHealthy.put(FEVER, 0.3);
        transitionProbility.put(HEALTHY, fromHealthy);

        Map<String, Double> fromFever = new HashMap<>();
        fromFever.put(HEALTHY, 0.4);
        fromFever.put(FEVER, 0.6);
        transitionProbility.put(FEVER, fromFever);

        Map<String, Double> healthySymptoms = new HashMap<>();
        healthySymptoms.put(NORMAL, 0.5);
        healthySymptoms.put(COLD, 0.4);
        healthySymptoms.put(DIZZY, 0.1);
        emissionProbalility.put(HEALTHY, healthySymptoms);

        Map<String, Double> feverSymptoms = new HashMap<>();
        feverSymptoms.put(NORMAL, 0.1);
        feverSymptoms.put(COLD, 0.3);
        feverSymptoms.put(DIZZY, 0.6);
        emissionProbalility.put(FEVER, feverSymptoms);
    }

    public static void main(String[] args) {
        viterbi(observations, states, startProbability, transitionProbility, emissionProbalility);
    }

    /**
     * Prints the Viterbi tables for {@code observations}, followed by the most likely
     * hidden-state path.
     *
     * <p>Output format:
     * <ol>
     *   <li>For every day from the last to the first, and for every state, one line:
     *       {@code "<day> <state>:<best probability> <predecessor state>"}.</li>
     *   <li>The most likely path from the last day back to day 0, one line per day:
     *       {@code "<state> <probability>"} (see {@link #output}).</li>
     * </ol>
     * Nothing is printed when {@code observations} or {@code states} is empty.
     *
     * @param observations        observed symbols, in chronological order
     * @param states              hidden states of the model
     * @param startProbability    probability of each state on day 0; a missing entry counts as 0
     * @param transitionProbility {@code transitionProbility.get(from).get(to)}; missing entries count as 0
     * @param emissionProbalility {@code emissionProbalility.get(state).get(observation)}; missing entries count as 0
     */
    public static void viterbi(String[] observations, Set<String> states, Map<String, Double> startProbability,
                               Map<String, Map<String, Double>> transitionProbility,
                               Map<String, Map<String, Double>> emissionProbalility) {
        if (observations.length == 0 || states.isEmpty()) {
            return;
        }
        List<Map<String, Double>> visibleProbality = new ArrayList<>(observations.length);
        List<Map<String, String>> path = new ArrayList<>(observations.length);
        fillTables(observations, states, startProbability, transitionProbility, emissionProbalility,
                visibleProbality, path);

        // For each day, print every state's best probability and the previous-day state it came from.
        int lastDay = observations.length - 1;
        for (int day = lastDay; day >= 0; day--) {
            for (String state : states) {
                System.out.println(day + " " + state + ":" + visibleProbality.get(day).get(state)
                        + " " + path.get(day).get(state));
            }
        }
        // Print the most likely path, walking back from the best final state.
        output(bestState(visibleProbality.get(lastDay), states), visibleProbality, path, lastDay);
    }

    /**
     * Returns the most likely hidden-state sequence for {@code observations}, in chronological order.
     *
     * <p>Ties are broken in favor of the state encountered first when iterating {@code states}.
     * Probabilities are multiplied directly, not in log space. Very long sequences can therefore
     * underflow to 0, and from that point on the choice degenerates to tie-breaking.
     *
     * @param observations        observed symbols, in chronological order
     * @param states              hidden states of the model
     * @param startProbability    probability of each state on day 0; a missing entry counts as 0
     * @param transitionProbility {@code transitionProbility.get(from).get(to)}; missing entries count as 0
     * @param emissionProbalility {@code emissionProbalility.get(state).get(observation)}; missing entries count as 0
     * @return one state per observation; an empty list when {@code observations} or {@code states} is empty
     */
    public static List<String> mostLikelyPath(String[] observations, Set<String> states,
                                              Map<String, Double> startProbability,
                                              Map<String, Map<String, Double>> transitionProbility,
                                              Map<String, Map<String, Double>> emissionProbalility) {
        List<String> result = new ArrayList<>(observations.length);
        if (observations.length == 0 || states.isEmpty()) {
            return result;
        }
        List<Map<String, Double>> visibleProbality = new ArrayList<>(observations.length);
        List<Map<String, String>> path = new ArrayList<>(observations.length);
        fillTables(observations, states, startProbability, transitionProbility, emissionProbalility,
                visibleProbality, path);

        int lastDay = observations.length - 1;
        String state = bestState(visibleProbality.get(lastDay), states);
        for (int day = lastDay; day >= 0; day--) {
            result.add(state);
            state = path.get(day).get(state);
        }
        Collections.reverse(result);
        return result;
    }

    /**
     * Prints the path that ends in {@code state} on day {@code depth}, following back-pointers
     * down to day 0. Prints one {@code "<state> <probability>"} line per day, from the last day
     * to the first. Prints nothing when {@code depth < 0}.
     */
    public static void output(String state, List<Map<String, Double>> visibleProbality,
                              List<Map<String, String>> path, int depth) {
        // Iterative walk, so long sequences cannot overflow the call stack.
        String current = state;
        for (int day = depth; day >= 0; day--) {
            System.out.println(current + " " + visibleProbality.get(day).get(current));
            current = path.get(day).get(current);
        }
    }

    /**
     * Fills the dynamic-programming tables, appending one map per day to each output list.
     * {@code visibleProbality.get(i).get(s)} is the best probability of any state sequence that
     * ends in {@code s} on day i. {@code path.get(i).get(s)} is the day-(i-1) state on that best
     * sequence; on day 0 it is {@code s} itself.
     * Requires non-empty {@code observations} and {@code states}.
     */
    private static void fillTables(String[] observations, Set<String> states,
                                   Map<String, Double> startProbability,
                                   Map<String, Map<String, Double>> transitionProbility,
                                   Map<String, Map<String, Double>> emissionProbalility,
                                   List<Map<String, Double>> visibleProbality,
                                   List<Map<String, String>> path) {
        // Day 0: start probability times the probability of emitting the first observation.
        Map<String, Double> firstProbabilities = new HashMap<>();
        Map<String, String> firstBackPointers = new HashMap<>();
        for (String state : states) {
            firstProbabilities.put(state, probability(startProbability, state)
                    * probability(emissionProbalility, state, observations[0]));
            firstBackPointers.put(state, state);
        }
        visibleProbality.add(firstProbabilities);
        path.add(firstBackPointers);

        // Later days: for each current state, keep the best predecessor.
        for (int i = 1; i < observations.length; i++) {
            Map<String, Double> previous = visibleProbality.get(i - 1);
            Map<String, Double> probabilities = new HashMap<>();
            Map<String, String> backPointers = new HashMap<>();
            for (String nowState : states) {
                double emission = probability(emissionProbalility, nowState, observations[i]);
                String bestPrevious = null;
                // Start below any real probability so a predecessor is always chosen, even when all are 0.
                double bestProbability = Double.NEGATIVE_INFINITY;
                for (String oldState : states) {
                    double p = previous.get(oldState)
                            * probability(transitionProbility, oldState, nowState)
                            * emission;
                    if (p > bestProbability) {
                        bestProbability = p;
                        bestPrevious = oldState;
                    }
                }
                probabilities.put(nowState, bestProbability);
                backPointers.put(nowState, bestPrevious);
            }
            visibleProbality.add(probabilities);
            path.add(backPointers);
        }
    }

    /**
     * Returns the state with the highest value in {@code probabilities}. Ties go to the first
     * state iterated. Returns {@code null} only if {@code states} is empty.
     */
    private static String bestState(Map<String, Double> probabilities, Set<String> states) {
        String best = null;
        double bestProbability = Double.NEGATIVE_INFINITY;
        for (String state : states) {
            double p = probabilities.get(state);
            if (p > bestProbability) {
                bestProbability = p;
                best = state;
            }
        }
        return best;
    }

    /** Looks up {@code table.get(key)}, treating a missing entry as probability 0. */
    private static double probability(Map<String, Double> table, String key) {
        Double value = table.get(key);
        return value == null ? 0.0 : value;
    }

    /** Looks up {@code table.get(row).get(column)}, treating a missing row or entry as probability 0. */
    private static double probability(Map<String, Map<String, Double>> table, String row, String column) {
        Map<String, Double> distribution = table.get(row);
        return distribution == null ? 0.0 : probability(distribution, column);
    }
}
// 分词时用到了结巴（jieba），借此学习了一下 Viterbi 算法，并用 Java 改写了一遍。
// Python 的写法确实比 Java 方便多了；如果把字符串映射成数字编号，代码会简洁一些。