nn.CrossEntropyLoss()
的参数
torch.nn.CrossEntropyLoss(weight=None, size_average=None,
ignore_index=-100, reduce=None, reduction=‘mean’)
weight
:不必多说,这就是各class的权重。所以它的值必须满足两点:- type = torch.Tensor
- weight.shape = tensor(1, class_num)
size_average
、reduce
:都要被弃用了,直接看reduction
就行reduction
:结果的规约方式,取值空间为{'mean', 'none', 'sum}
。由于你传入nn.CrossEntropyLoss()
的输入是一个batch,那么按理说得到的交叉熵损失应该是batch
个loss。当前默认的处理方式是,对batch
个损失取平均;也可以选择不做规约;或者将batch
个损失取加和;ignore_index
:做交叉熵计算时,若输入为ignore_index
指定的数值,则该数值会被忽略,不参与交叉熵计算。
BERT中是怎么做到只计算[MASK]token的CrossEntropyLoss的?
nn.CrossEntropyLoss()
的ignore_index
参数在BERT的mask中用到了。由于BERT中其中一个预训练任务是MLM,只有15%的token被[MASK],所以说只有这15%的词会参与交叉熵loss的计算,其他85%不参与loss计算的槽位,就使用-1填充;而参与loss计算的槽位,会使用在 vocab.txt
里提前定义好的原始token对应的index表示,这些index都是大于101
([CLS])的,所以计算时不会被ignore