理解 GitHub 上的代码:Bert-BiLSTM-CRF-pytorch。
GitHub 相关链接:(原文链接占位符,待补充)。
以下这部分代码用于解码(Viterbi 解码)阶段。
def _viterbi_decode(self, feats, mask=None):
    """Batched Viterbi decoding over CRF emission + transition scores.

    Args:
        feats: size=(batch_size, seq_len, self.target_size+2)
            emission scores, tag set augmented with START/STOP tags.
        mask: size=(batch_size, seq_len)
            1 for real tokens, 0 for padding.

    Returns:
        decode_idx: (batch_size, seq_len), Viterbi decode result
        path_score: size=(batch_size, 1), score of each sentence

    NOTE(review): this view of the function is truncated — the
    back-tracking over `back_points` and the actual return statement
    continue past the visible lines.
    """
    batch_size = feats.size(0)
    seq_len = feats.size(1)
    # tag_size = target_size + 2 (START/STOP-augmented), per the docstring.
    tag_size = feats.size(-1)
    # Per-sentence token count; used below to gather the partition at
    # each sequence's true last position. NOTE(review): assumes mask is
    # 1 for real tokens / 0 for padding — confirm against the caller.
    length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
    # Make time the leading dimension: (seq_len, batch_size).
    mask = mask.transpose(1, 0).contiguous()
    ins_num = seq_len * batch_size
    # Broadcast each emission row over the "previous tag" axis:
    # (ins_num, tag_size_from, tag_size_to).
    feats = feats.transpose(1, 0).contiguous().view(
        ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
    # scores[t, b, i, j] = emission(t, b, j) + transitions[i, j].
    scores = feats + self.transitions.view(
        1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
    scores = scores.view(seq_len, batch_size, tag_size, tag_size)
    seq_iter = enumerate(scores)
    # record the position of the best score
    back_points = list()
    partition_history = list()
    # Invert the mask so 1 marks PAD positions (consumed by
    # masked_fill_ on the back-pointers below).
    mask = (1 - mask.long()).byte()
    # Py3 generators use __next__; the bare-except fallback covers
    # Py2's .next(). NOTE(review): a bare except is overly broad.
    try:
        _, inivalues = seq_iter.__next__()
    except:
        _, inivalues = seq_iter.next()
    # Initial partition: score of moving from START to each tag at t=0.
    partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
    partition_history.append(partition)
    for idx, cur_values in seq_iter:
        # Viterbi recursion: add the best score so far for each source
        # tag, then max over source tags (dim=1).
        cur_values = cur_values + partition.contiguous().view(
            batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        partition, cur_bp = torch.max(cur_values, 1)
        partition_history.append(partition.unsqueeze(-1))
        # Zero the back-pointers at padded positions so padding never
        # influences the back-trace.
        cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
        back_points.append(cur_bp)
    # Stack per-step partitions: (batch_size, seq_len, tag_size).
    partition_history = torch.cat(partition_history).view(
        seq_len, batch_size, -1).transpose(1, 0).contiguous()
    # Index of each sentence's last real token (length - 1), expanded
    # over the tag axis for gather().
    last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
    # Partition values at each sentence's final real position.
    last_partition = torch.gather(
        partition_history, 1, last_position).view(batch_size, tag_size, 1)
    # Add the transition out of each final tag (presumably into STOP is
    # selected further down — the continuation is not visible here).
    last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
        self.transitions