什么是 logits?
logits
是模型输出的未归一化预测值,通常是全连接层的输出。在分类任务中,logits 的形状通常为 (batch_size, num_labels)
,其中 batch_size
是一个批次中的样本数,num_labels
是分类任务中的类别数。
logits
是模型的输出。假设logits
的形状为(batch_size, num_labels)
,例如(32, 3)
,表示每个批次有32
个样本,每个样本有3
个类别的预测值。
什么是交叉熵损失函数?
交叉熵损失函数(Cross-Entropy Loss
)是一种常用于分类任务的损失函数。它衡量的是预测分布与真实分布之间的差异。具体而言,它会计算每个样本的预测类别与真实类别之间的距离,然后取平均值。
在 PyTorch 中,交叉熵损失函数可以通过 torch.nn.CrossEntropyLoss
来实现。该函数结合了 LogSoftmax
和 NLLLoss
两个操作,适用于未归一化的 logits
。
具体示例
假设有一个分类任务,模型的输出和标签如下:
logits