function loss = cross_entropy(yhat, y)
% yhat: N*K matrix, 预测输出的概率分布
% y: N*K matrix, 真实标签(one-hot编码)
% loss: 交叉熵损失值
% 离散化计算交叉熵
tmp = y .* log(yhat);
loss = -1/N * sum(sum(tmp));
batchsize大于1时,计算交叉熵的matlab代码,要求使用矩阵运算,而不是for循环
最新推荐文章于 2024-10-31 09:22:09 发布