HMM 隐马尔科夫链 viterbi算法 Pytorch实现

  • 基于pytorch实现了维特比算法,用于求解隐马尔科夫链的预测问题,由于重点在于理解算法并能用code表达出来,所以并未对于状态转移矩阵、发射概率矩阵和初始状态概率进行估计。实际上A,B,Pi的估计比起viterbi算法本身难度要低很多,直接基于训练数据,用类似矩估计的思想去统计频率直接得到A,B,Pi即可。
  • 代码中添加了很详细的注释,新人第一次发博客,希望能够帮助像我一样的小白理解viterbi算法,毕竟无论是HMM还是CRF,viterbi都是其中的灵魂,而弄清楚viterbi算法也能够让自己的动态规划思想更上一层楼(DP大神除外),以下是代码部分
import torch

# 维特比算法解决HMM预测问题,给定状态转移矩阵A,发射概率矩阵B,初始状态概率Pi
def viterbi(self, word_list, word2id, state2id, A, B, Pi):
    # 初始化viterbi矩阵,shape = [状态数, 序列长度],viterbi[i, j]表示当序列的第j个元素的状态为i时,前j个元素的观测链概率的最大值
    # 初始化backpointer矩阵, shape = [状态数, 序列长度],backpointer[i, j]表示当序列的第j个元素的状态为i时,使得前j个元素的观测链概率能够达到最大值的第j-1个元素的状态
    A, B, Pi = torch.log(A), torch.log(B), torch.log(Pi)
    N, seq_len = len(state2id), len(word_list)
    viterbi = torch.zeros(N, seq_len)
    backpointer = torch.zeros(N, seq_len)

    B_t = B.t()  # shape=[M, N],B[word_id]表示当前观测为word_id时各状态的概率
    start_word_id = word2id.get(word_list[0], None)
    if start_word_id is None:
        # 如果当前观测不在词表中,则假设其发射概率服从均匀分布
        b_t = torch.log(torch.ones(N) / N)
    else:
        b_t = B_t[start_word_id]

    viterbi[:, 0] = Pi + b_t  # 第一个元素的观测概率 = 初始状态概率 * 发射概率
    backpointer[:, 0] = -1  # start_word之前的元素并不存在状态,故取-1

    for step in range(1, seq_len):
        word_id = word2id.get(word_list[step], None)
        if word_id is None:
            # 如果当前观测不在词表中,则假设其发射概率服从均匀分布
            b_t = torch.log(torch.ones(N) / N)
        else:
            b_t = B_t[word_id]
        # 计算第step个观测元素的状态为state时,前step个元素的观测链概率的最大值,以及观测链概率最大时,第step-1个元素的状态
        for state_id in range(len(state2id)):
            # 前step步的最优路径必定包含前step-1步的最优路径,只需要乘上状态转移概率和发射概率,然后求max即可
            # (由于发射概率对于特定的观测与状态是相同的,所以也可以先求max,再乘上发射概率)
            max_prob, best_state_id = torch.max(viterbi[:, step - 1] + A[:, state_id], dim=0)
            viterbi[state_id, step] = max_prob + b_t[state_id]
            backpointer[state_id, step] = best_state_id

    # 终止,并且从最后一个元素开始回溯
    max_prob, best_state_id = torch.max(viterbi[:, seq_len - 1], dim=0)
    # 反向保存最优路径
    best_path = [best_state_id.item()]
    for step in range(seq_len - 1, 0, -1):
        # backpointer[i, j]中存储了第j-1到j个元素的最优路径,即第j-1个元素的最优状态
        best_state_id = backpointer[best_state_id, step]
        best_path.append(best_state_id.item())

    # 将state_id组成的逆序序列翻转,并转化为state
    assert len(best_path) == len(word_list)
    id2state = dict((id, state) for state, id in state2id.items())
    best_state_path = [id2state[id] for id in reversed(best_path)]

    return best_state_path, max_prob
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值