CrossEntropyLoss is equivalent to softmax + log + NLLLoss (i.e. LogSoftmax followed by NLLLoss)
LogSoftmax is equivalent to softmax followed by log
It can be used to compute the loss for text classification, sequence labeling, and similar tasks
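The equivalence claimed above can be checked numerically. A minimal sketch (tensor values here are arbitrary):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 3)           # raw scores for 4 samples, 3 classes
labels = torch.tensor([0, 1, 1, 2])  # gold class indices

ce = torch.nn.CrossEntropyLoss()(logits, labels)
nll = torch.nn.NLLLoss()(torch.nn.LogSoftmax(dim=1)(logits), labels)
# ce and nll agree up to floating-point tolerance
```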
Usage:
# First instantiate the class
cross_entropy_loss = torch.nn.CrossEntropyLoss()
# Then call it with predictions and labels
loss = cross_entropy_loss(input, target)
input has shape N*C: the raw scores (logits) produced by the network, where N is the batch size and C is the number of classes;
target has shape N: the gold labels given as class indices, not one-hot vectors;
import torch
import torch.nn as nn

input = torch.randn(4, 3)
target = torch.tensor([0, 1, 1, 2])  # must be Long type: class indices
cross_entropy_loss = nn.CrossEntropyLoss()
loss = cross_entropy_loss(input, target)
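As a sanity check on what the default reduction computes, reduction='none' returns the per-sample losses, and their mean matches the default reduction='mean' (variable names below are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)           # N=4 samples, C=3 classes
labels = torch.tensor([0, 1, 1, 2])  # class indices, dtype long

per_sample = nn.CrossEntropyLoss(reduction='none')(logits, labels)  # shape (4,)
mean_loss = nn.CrossEntropyLoss()(logits, labels)                   # scalar mean of the four
```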
# For sequence labeling, reshape first
input = torch.randn(2, 4, 3)  # 2 is the batch size, 4 the sequence length, 3 the number of classes
input = input.view(-1, 3)  # 8 tokens in total
target = torch.tensor([[0, 1, 1, 2], [2, 2, 1, 0]])  # every index must lie in [0, C), i.e. 0..2 here
target = target.view(-1)
loss = cross_entropy_loss(input, target)  # reduction defaults to 'mean'
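For sequence labeling with padded batches, CrossEntropyLoss's ignore_index argument (default -100) skips pad positions when computing the mean. The padding scheme below is a hypothetical example:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # -100 is also the default ignore_index
logits = torch.randn(2, 4, 3).view(-1, 3)
# second sequence has only one real token; pad positions are labeled -100
labels = torch.tensor([[0, 1, 1, 2], [2, -100, -100, -100]]).view(-1)
loss = loss_fn(logits, labels)  # mean over the 5 non-ignored tokens only
```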