交叉损失熵函数
在深度学习进行分类任务时经常用到交叉损失熵函数:首先定义logic表示余弦相似度,labels表示真实标签。可以直接使用交叉熵损失函数
import torch.nn as nn
Loss = nn.CrossEntropyLoss(logic, labels)
其中logic是一个NXC大小的数组,N表示有多少个样本(BatchSize),C表示类别总数。例如,
表示有三个样本,四个类别。
而labels是一个一维数组,大小为N。表示N个样本所对应的真实类别。如label=[3,0,1],表示一个有三个样本,第一个样本对应的真实标签为3,第二个真实标签为0,第三个真实标签为1.