Balanced-MixUp的自定义的交叉熵损失函数
def cross_entropy_loss(input: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
return -(input.log_softmax(dim=-1) * target).sum(dim=-1).mean()
官方的 nn.CrossEntropyLoss()
import torch.nn as nn
criterion = nn.CrossEntropyLoss()
主要区别
-
输入类型和形状:
- 自定义函数:假设
input
是一个具有 logits 的张量,形状为(batch_size, num_classes)
,而target
是独热编码的标签,形状也是(batch_size, num_classes)
。 - 官方函数:假设
input
是一个具有 logits 的张量,形状为(batch_size, num_classes)
,而target
是一个长整型张量,形状为(batch_size)
,每个值表示对应样本的类索引。
- 自定义函数:假设
mixup的相关操作是怎样将类索引标签变成独热编码的
mixed_labels = (1 - lam) * F.one_hot(labels, n_classes) + lam * F.one_hot(balanced_labels, n_classes)