Bert-BiLSTM-CRF pytorch 代码解析-1:def _forward_alg(self, feats, mask=None)

理解 github上代码:Bert-BiLSTM-CRF-pytorch
Github 相关链接: link.

neg_log_likelihood_loss = forward_score - gold_score
这部分应该是为了计算所有路径的分数(forward_score )

    def _forward_alg(self, feats, mask=None):
        """
        Do the forward algorithm to compute the partition function (batched).

        Args:
            feats: size=(batch_size, seq_len, self.target_size+2)
            mask: size=(batch_size, seq_len)

        Returns:
            xxx
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(-1)

        # 1. mask 转置 后 shape 为: (seq_len, batch), 
        #    feats 原先 shape=(batch_size, seq_len, tag_size) 
        #          先转置:    (seq_len, batch_size, tag_size)
        #          view:  (seq_len*batch_size, 1, tag_size)
        #          然后在 -2 维度复制: (seq_len*batch_size, [tag_size], tag_size)
        mask = mask.transpose(1, 0).contiguous()
        ins_num = batch_size * seq_len
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

        # 2. scores: LSTM所有时间步的输出 feats 先加上 转移分数
        scores = feats + self.transitions.view(
            1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
        seq_iter = enumerate(scores) 
        # seq_iter: t=0 开始的LSTM所有时间步迭代输出
        # inivalues: t=1 开始的LSTM所有时间步迭代输出
        try:
            _, inivalues = seq_iter.__next__()
        except:
            _, inivalues = seq_iter.next()

        # 2. 计算 a 在 t=0 时刻的初始值
        partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
        # 3. 迭代计算 a (即partition ) 在 t=1,2,。。。更新的值
        for idx, cur_values in seq_iter: # fro idx = 1,2,3..., cur_values是LSTM输出+转移分数的值
            cur_values = cur_values + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            cur_partition = log_sum_exp(cur_values, tag_size)
            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
            masked_cur_partition = cur_partition.masked_select(mask_idx.byte())
            if masked_cur_partition.dim() != 0:
                # 将mask_idx中值为1元素对应的masked_cur_partition中位置的元素复制到本partition中。
                # mask应该有和partition相同数目的元素。
                # 即 mask 部分的 partition值不再更新
                mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
                partition.masked_scatter_(mask_idx.byte(), masked_cur_partition)
        
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(
                batch_size, tag_size, tag_size) + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size,
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值