1. 交叉熵(torch.nn.CrossEntropyLoss)
参数
class CrossEntropyLoss(_WeightedLoss):
def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True):
pass
def forward(self, input, target):
pass
两个参数:
input:形状:NxC,其中C为类别数量;
target:形状:N,一维张量。(CrossEntropyLoss()会把target自动变成one-hot形式,若例子的样本标签是1(从0开始计算)。那么转换成的one−hot编码就是[01000],target也变成了 NxC 维张量。
注意:target必须是torch.long类型。
示例
import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)
print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)
输出:
可以看到,loss输出的是平均值。