import torch
import torch.nn as nn
import math

# Demo: nn.CrossEntropyLoss(output, label) equals the mean over samples of
#   -logit[target] + log(sum_j exp(logit_j))
# i.e. log-sum-exp (softmax) is built into the loss. We verify this two ways:
# once from raw logits, once by applying Softmax explicitly and taking -log.
print('不使用softmax计算CrossEntropyLoss:')
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)          # logits: 3 samples x 5 classes
label = torch.empty(3, dtype=torch.long).random_(5)     # random target class per sample
loss = criterion(output, label)
print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

# Generalized: read batch/class sizes from the tensor instead of
# hard-coding 3 and 5 in every loop.
num_samples, num_classes = output.shape
# first[i]  = -logit of the target class for sample i
first = [-output[i][label[i]] for i in range(num_samples)]
# second[i] = sum over all classes of exp(logit) for sample i
second = [sum(math.exp(output[i][j]) for j in range(num_classes))
          for i in range(num_samples)]
res = 0
for i in range(num_samples):
    res += first[i] + math.log(second[i])
print("自己的计算结果:")
print(res / num_samples)

print('~~' * 50)
# Fixed typo in the message below: "sotfmax" -> "softmax".
print('CrossEntropyLoss包含softmax,下面使用softmax计算CrossEntropyLoss:')
print('softmax后输出:')
softmax = nn.Softmax(dim=1)
softmax_output = softmax(output)
print(softmax_output)
# Equivalent route: cross-entropy = mean of -log(softmax probability of target).
sum_soft = 0
for i in range(num_samples):
    sum_soft += -math.log(softmax_output[i][label[i]])
print('softmax后计算结果:')
print(sum_soft / num_samples)
print("计算loss的结果:")
print(loss)
CrossEntropyLoss计算过程
最新推荐文章于 2024-06-14 09:50:01 发布