简介
为了更好地理解 pytorch 的 CrossEntropyLoss,于是打算进行简单的实现。
官方文档:
https://pytorch.org/docs/stable/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss
官网 loss 的公式:
x 的维度是 (batch_size, C)
class 的维度是 (batch_size)
(这里的 C 是分类的个数)
核心代码
写了一个类,利用 numpy 进行实现
input 对应的是上面公式的 x,target 对应的是 class
核心代码如下(实现公式):
batch_loss = 0.
for i in range(input.shape[0]):
numerator = np.exp(input[i, target[i]]) # 分子
denominator = np.sum(np.exp(input[i, :])) # 分母
loss = -np.log(numerator / denominator)
batch_loss += loss
上面公式中,两个输入 x,class 分别对应代码里的 input,target。
每一个循环,