在学习BiLSTM+CRF进行NER任务时,处理样本数据遇到维度问题,参考GitHub代码后解决,以计算正确实例的路径分数为例
a l l _ s c o r e s : ( b a t c h _ s i z e , l e n g t h , l a b e l _ s i z e , l a b e l _ s i z e ) all\_scores: (batch\_size, length, label\_size, label\_size) all_scores:(batch_size,length,label_size,label_size)
a l l _ s c o r e s [ i ] [ j ] [ k ] [ m ] all\_scores[i][j][k][m] all_scores[i][j][k][m]表示第 i i i个样本,第 j j j个位置,由 l a b e l [ k ] ( 上 一 个 ) label[k](上一个) label[k](上一个)转移为 l a b e l [ m ] ( 当 前 标 签 ) label[m](当前标签) label[m](当前标签)的转移分数与发射分数之和。
t a g s : ( b a t c h _ s i z e , l e n g t h ) tags : (batch\_size, length) tags:(batch_size,length)
# tag_gather (batch_size, length, label_size, 1)
tag_gather = tags.view(batchSize, sentLength, 1, 1).expand(batchSize, sentLength, self.label_size, 1)
# currentTagScores[i][j][k][0] = all_scores[i][j][k][tag_gather[i][j][k][0]]
# currentTagScores 表示表示第i个样本,第j个位置上,所有label到正确tag的概率
currentTagScores = torch.gather(all_scores, 3, tag_gather).view(batchSize, -1, self.label_size)
注意 t a g _ g a t h e r [ i ] [ j ] [ k ] tag\_gather[i][j][k] tag_gather[i][j][k]对于任意 k k k都相同,都是第 i i i个样本中第 j j j个位置的标签。
# 第0个位置 由[start]到当前tag的分数
tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
# 计算中间位置的分数
tag_middle_gather = tags[:, : sentLength - 1].view(batchSize, sentLength - 1, 1)
# tagTransScoresMiddle[i][j][0] = currentTagScores[i][j][tag_middle_gather[i][j][0]]
# tagTransScoresMiddle 表示第i个样本,第j个位置的分数
tagTransScoresMiddle = torch.gather(currentTagScores[:, 1:, :], 2, tag_middle_gather).view(batchSize, -1)
最后利用masked_select函数求得相应长度的分数。
score += torch.sum(tagTransScoresMiddle.masked_select(masks[:, 1:]))