torch.nn.functional.cross_entropy使用
介绍
官网介绍:torch.nn.functional.cross_entropy
F.cross_entropy
是用于计算交叉熵损失函数的函数。它的输出是一个表示给定输入的损失值的张量。F.cross_entropy
函数与nn.CrossEntropyLoss
类是相似的,但前者更适合于控制更多的细节。
由于内部已经使用SoftMax
函数处理,故两者的输入都不需要使用SoftMax
处理。
函数原型为:F.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
。
参数说明:
- input:(N, C)形状的张量,其中N为Batch_size,C为类别数。该参数对应于神经网络最后一个全连接层输出的未经Softmax处理的结果。
- target:一个大小为(N,)张量,其值是0 <= targets[i] <= C-1的元素,其中C是类别数。该参数包含一组给定的真实标签(ground truth)。
- weight:采用类别平衡的加权方式计算损失值。可以传入一个大小为(C,)张量,其中weight[j]是类别j的权重。默认为None。
- size_average:弃用,可以忽略。该参数已经被reduce参数取代了。
- ignore_index:指定被忽略的目标值的索引。如果目标值等于该索引,则不计算该样本的损失。默认值为-100,即不忽略任何目标值。
- reduce:指定返回的损失值的方式。可以是“None”(不返回损失值)、“mean”(返回样本损失值的平均值)和“sum”(返回样本损失值的总和)。默认值为“mean”。
- reduction:与reduce参数等价。表示返回的损失值的方式。默认值为“mean”。
计算过程
首先,我们有以下输入:
- predictions:一个2维tensor,表示模型的预测结果。它的形状是(2, 3),其中2是样本数量,3是类别数量。每一行对应一个样本的预测结果,每个元素表示该类别的概率。
predictions = torch.tensor([[0.304, 0.333, 0.363], [0.442, 0.279, 0.279]])
- labels:一个1维tensor,表示样本的真实标签。它的形状是(2,),其中2是样本数量。
labels = torch.tensor([2, 0])
在计算交叉熵损失之前,F.cross_entropy函数会对predictions进行softmax操作,以确保每行的元素和为1,并将其视为概率分布。在我们的示例中,predictions经过softmax操作后的结果为:
pred_softmax = torch.tensor([[0.3236, 0.3331, 0.3433],
[0.3705, 0.3148, 0.3148]])
接下来,我们定义了一个权重向量weights,用于指定每个类别的权重。在我们的示例中,权重向量为:
weights = torch.tensor([1.0, 2.0, 3.0])
然后,F.cross_entropy函数计算交叉熵损失。它首先计算每个样本的交叉熵损失,然后将所有样本的损失求平均。
对于第一个样本,它的预测结果为[0.304, 0.333, 0.363],真实标签为2。根据交叉熵损失的定义,我们可以计算出它的损失为:
-weights[2] * math.log(pred_softmax[0, 2])
其中log表示自然对数。在我们的示例中,根据权重和预测结果的计算,第一个样本的损失为:
-3.0 * math.log(0.3433) ≈ -3.0 * (-1.06915) ≈ 3.20745
对于第二个样本,它的预测结果为[0.442, 0.279, 0.279],真实标签为0。根据交叉熵损失的定义,我们可以计算出它的损失为:
-weights[0] * math.log(pred_softmax[1, 0])
在我们的示例中,根据权重和预测结果的计算,第二个样本的损失为:
-1.0 * math.log(0.3705) ≈ -1.0 * (-0.9929) ≈ 0.9929
最后,若reduction=‘none’,则不计算平均值;若reduction=‘mean’,则计算所有样本的损失求平均,得到最终的带权重的交叉熵损失值:
(3.20745 + 0.9929) / (1 + 3) ≈ 1.0500875
示例代码:
以下是一个示例代码,展示了如何使用weight参数来指定每个类别的权重:
import torch
import torch.nn.functional as F
import math
# 创建模型预测结果和真实标签
predictions = torch.tensor([[0.2, 0.3, 0.5], [0.8, 0.1, 0.1]])
labels = torch.tensor([2, 0])
# 定义每个类别的权重
weights = torch.tensor([1.0, 2.0, 3.0])
# 使用F.cross_entropy计算带权重的交叉熵损失
loss = F.cross_entropy(predictions, labels, weight=weights)
print(loss) # tensor(0.8773)
# 测试计算过程
pred = F.softmax(predictions, dim=1)
loss2 = -(3 * math.log(pred[0,2]) + math.log(pred[1,0]))/4 # 4 = 1+3 对应权重之和
print(loss2) # 0.8773049571540321
在这个例子中,我们创建了一个二维tensor predictions,表示模型的预测结果。每一行对应一个样本的预测结果,每个元素表示该类别的概率。我们还创建了一个一维tensor labels,表示样本的真实标签。
然后,我们定义了一个一维tensor weights,用于指定每个类别的权重。在计算交叉熵损失时,权重会被应用到每个类别上。
最后,使用F.cross_entropy函数计算了predictions和labels之间的交叉熵损失,同时考虑了类别权重。得到了一个标量tensor loss,表示模型的带权重的损失值。
通过调整权重,我们可以提高或降低某些类别对模型损失的贡献,从而影响模型的训练过程。权重的选择应根据具体问题和需求进行调整。
应用于多分类损失示例
在PyTorch中,计算多分类的交叉熵损失通常使用 torch.nn.CrossEntropyLoss
类。这个类结合了 log_softmax
和 NLLLoss
(负对数似然损失),因此可以直接用于多分类问题。
假设你有一个多分类输出 out
,其形状为 (batch_size, num_classes, height, width)
,以及对应的标签 target
,其形状为 (batch_size, height, width)
,其中 target
中的每个元素是一个类别的索引(从 0
到 num_classes-1
)。
1. 定义损失函数
首先,你需要定义一个 CrossEntropyLoss
对象:
import torch
import torch.nn as nn
# 定义损失函数
criterion = nn.CrossEntropyLoss()
2. 计算损失
接下来,你可以使用这个损失函数来计算损失。注意,CrossEntropyLoss
期望输入的形状为 (batch_size, num_classes, ...)
,而标签的形状为 (batch_size, ...)
。
# 假设 out 是你的多分类输出,target 是你的标签
out = torch.randn(batch_size, num_classes, height, width) # 示例数据
target = torch.randint(0, num_classes, (batch_size, height, width)) # 示例标签
# 计算损失
loss = criterion(out, target)
3.生成每个类别的掩码
import torch
# 假设 out 是你的多分类输出
out = torch.randn(1, n, 256, 256) # 示例数据
# 获取每个像素的类别索引
class_indices = torch.argmax(out, dim=1) # 形状为 (1, 256, 256)
# 初始化掩码张量
masks = torch.zeros(n, 256, 256, dtype=torch.uint8) # 形状为 (n, 256, 256)
# 为每个类别生成掩码
for i in range(n):
masks[i] = (class_indices == i).squeeze(0) # 形状为 (256, 256)
# 可选:转换为布尔掩码
masks = masks.bool()
4. 注意事项
out
的形状应为(batch_size, num_classes, height, width)
。target
的形状应为(batch_size, height, width)
,并且每个元素是一个类别的索引(从0
到num_classes-1
)。CrossEntropyLoss
会自动对out
进行log_softmax
操作,因此你不需要在out
上手动应用softmax
。
补充:F.binary_cross_entropy和F.binary_cross_entropy_with_logits区别
在 PyTorch 中,torch.nn.functional
模块提供了两种用于计算二元交叉熵损失的函数:F.binary_cross_entropy
和 F.binary_cross_entropy_with_logits
。它们的主要区别在于输入的格式和处理方式。
1. F.binary_cross_entropy
F.binary_cross_entropy
直接计算二元交叉熵损失,适用于已经经过 Sigmoid 函数处理的预测结果。
输入要求:
input
: 形状为(N, *)
的张量,表示模型的预测结果,取值范围在 [0, 1] 之间。target
: 形状为(N, *)
的张量,表示真实标签,取值为 0 或 1。
示例代码:
import torch
import torch.nn.functional as F
# 假设 result 和 label 已经定义
result = torch.rand((3, 256, 256)) # 预测结果,shape 为 (3, 256, 256),取值范围在 [0, 1] 之间
label = torch.randint(0, 2, (3, 256, 256)).float() # 标签,shape 为 (3, 256, 256),取值为 0 或 1
# 计算损失
loss = F.binary_cross_entropy(result, label)
print(loss)
2. F.binary_cross_entropy_with_logits
F.binary_cross_entropy_with_logits
结合了 Sigmoid 函数和二元交叉熵损失,适用于未经过 Sigmoid 函数处理的预测结果。它通常用于模型的最后一层输出是未经激活函数处理的情况。
输入要求:
input
: 形状为(N, *)
的张量,表示模型的预测结果,取值范围可以是任意实数。target
: 形状为(N, *)
的张量,表示真实标签,取值为 0 或 1。
示例代码:
import torch
import torch.nn.functional as F
# 假设 logits 和 label 已经定义
logits = torch.randn((3, 256, 256)) # 预测结果,shape 为 (3, 256, 256),取值范围可以是任意实数
label = torch.randint(0, 2, (3, 256, 256)).float() # 标签,shape 为 (3, 256, 256),取值为 0 或 1
# 计算损失
loss = F.binary_cross_entropy_with_logits(logits, label)
print(loss)
主要区别:
-
输入格式:
F.binary_cross_entropy
要求输入的预测结果已经经过 Sigmoid 函数处理,取值范围在 [0, 1] 之间。F.binary_cross_entropy_with_logits
适用于未经过 Sigmoid 函数处理的预测结果,取值范围可以是任意实数。
-
内部处理:
F.binary_cross_entropy_with_logits
内部会自动应用 Sigmoid 函数,然后再计算二元交叉熵损失。
使用场景:
- 如果你的模型最后一层已经应用了 Sigmoid 函数,使用
F.binary_cross_entropy
。 - 如果你的模型最后一层没有应用 Sigmoid 函数,使用
F.binary_cross_entropy_with_logits
。
注意事项:
F.binary_cross_entropy_with_logits
通常在数值稳定性上优于手动应用 Sigmoid 函数后再计算F.binary_cross_entropy
,因为它避免了数值不稳定性问题。
补充:交叉熵通俗理解
交叉熵损失函数(Cross-Entropy Loss)是一种常用的损失函数,主要用于衡量分类问题中模型预测结果与真实标签之间的差异。
为了理解交叉熵损失函数,我们可以从两个方面进行解释:
-
信息论解释:交叉熵损失函数可以被看作是真实标签的信息量在模型预测概率分布上的平均值。信息论中的熵(Entropy)是衡量一个随机变量的不确定性的度量,而交叉熵则是真实分布与预测分布之间的差异。当预测分布与真实分布越接近时,交叉熵的值越小,表示模型的预测结果越准确。
-
最大似然估计解释:交叉熵损失函数可以被看作是最大似然估计的负对数似然。最大似然估计是一种统计方法,用于根据观测数据来估计概率分布的参数。交叉熵损失函数在分类问题中,将真实标签视为观测数据,并通过最小化交叉熵来寻找最适合的模型参数,使得模型能够最大程度地拟合真实标签的分布。
简而言之,交叉熵损失函数可以衡量模型预测结果与真实标签之间的差异,并通过最小化交叉熵来寻找最优的模型参数。当模型的预测结果与真实标签越接近时,交叉熵的值越小,表示模型的预测越准确。交叉熵损失函数在分类问题中被广泛使用,并被证明是一种有效的优化目标。