Bert-BiLSTM-CRF pytorch 代码解析-3:def _viterbi_decode

理解 github上代码:Bert-BiLSTM-CRF-pytorch
Github 相关链接: link.

这部分用于解码阶段

def _viterbi_decode(self, feats, mask=None):
    """
    Args:
        feats: size=(batch_size, seq_len, self.target_size+2)
        mask: size=(batch_size, seq_len)

    Returns:
        decode_idx: (batch_size, seq_len), viterbi decode结果
        path_score: size=(batch_size, 1), 每个句子的得分
    """
    batch_size = feats.size(0)
    seq_len = feats.size(1)
    tag_size = feats.size(-1)

    length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
    mask = mask.transpose(1, 0).contiguous()
    ins_num = seq_len * batch_size
    feats = feats.transpose(1, 0).contiguous().view(
        ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

    scores = feats + self.transitions.view(
        1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
    scores = scores.view(seq_len, batch_size, tag_size, tag_size)

    seq_iter = enumerate(scores)
    # record the position of the best score
    back_points = list()
    partition_history = list()
    mask = (1 - mask.long()).byte()
    try:
        _, inivalues = seq_iter.__next__()
    except:
        _, inivalues = seq_iter.next()
    partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
    partition_history.append(partition)

    for idx, cur_values in seq_iter:
        cur_values = cur_values + partition.contiguous().view(
            batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        partition, cur_bp = torch.max(cur_values, 1)
        partition_history.append(partition.unsqueeze(-1))

        cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
        back_points.append(cur_bp)

    partition_history = torch.cat(partition_history).view(
        seq_len, batch_size, -1).transpose(1, 0).contiguous()

    last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
    last_partition = torch.gather(
        partition_history, 1, last_position).view(batch_size, tag_size, 1)

    last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
        self.transitions
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值