理解 GitHub 上的代码：Bert-BiLSTM-CRF-pytorch
GitHub 相关链接：link（原文此处为超链接，链接地址在转载时丢失）。
neg_log_likelihood_loss = forward_score - gold_score
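其中 forward_score 由下面的 `_forward_alg` 计算：它用前向算法（forward algorithm）对所有可能标签路径的分数做 log-sum-exp，得到配分函数（partition function）的对数；再减去真实标签路径的分数 gold_score，即得到真实路径的负对数似然损失。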
def _forward_alg(self, feats, mask=None):
"""
Do the forward algorithm to compute the partition function (batched).
Args:
feats: size=(batch_size, seq_len, self.target_size+2)
mask: size=(batch_size, seq_len)
Returns:
xxx
"""
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(-1)
# 1. mask 转置 后 shape 为: (seq_len, batch),
# feats 原先 shape=(batch_size, seq_len, tag_size)
# 先转置: (seq_len, batch_size, tag_size)
# view: (seq_len*batch_size, 1, tag_size)
# 然后在 -2 维度复制: (seq_len*batch_size, [tag_size], tag_size)
mask = mask.transpose(1, 0).contiguous()
ins_num = batch_size * seq_len
feats = feats.transpose(1, 0).contiguous().view(
ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
# 2. scores: LSTM所有时间步的输出 feats 先加上 转移分数
scores = feats + self.transitions.view(
1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
seq_iter = enumerate(scores)
# seq_iter: t=0 开始的LSTM所有时间步迭代输出
# inivalues: t=1 开始的LSTM所有时间步迭代输出
try:
_, inivalues = seq_iter.__next__()
except:
_, inivalues = seq_iter.next()
# 2. 计算 a 在 t=0 时刻的初始值
partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
# 3. 迭代计算 a (即partition ) 在 t=1,2,。。。更新的值
for idx, cur_values in seq_iter: # fro idx = 1,2,3..., cur_values是LSTM输出+转移分数的值
cur_values = cur_values + partition.contiguous().view(
batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
cur_partition = log_sum_exp(cur_values, tag_size)
mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
masked_cur_partition = cur_partition.masked_select(mask_idx.byte())
if masked_cur_partition.dim() != 0:
# 将mask_idx中值为1元素对应的masked_cur_partition中位置的元素复制到本partition中。
# mask应该有和partition相同数目的元素。
# 即 mask 部分的 partition值不再更新
mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
partition.masked_scatter_(mask_idx.byte(), masked_cur_partition)
cur_values = self.transitions.view(1, tag_size, tag_size).expand(
batch_size, tag_size, tag_size) + partition.contiguous().view(
batch_size, tag_size, 1).expand(batch_size, tag_size,