The setup is simple: a custom dataloader introduced the following hard-to-spot bug.
inputs shape: [bs, 4], which is fine.
labels shape: should be [bs, 1], but the dataloader returns [bs].

Model outputs shape: [bs, 1]

If you compute the L1 loss / MAE as torch.mean(torch.abs(labels - outputs)),

then, because of PyTorch's broadcasting rules, torch.abs(labels - outputs) silently becomes a bs × bs matrix, and the mean is taken over that whole matrix: every label gets compared against every output, as if the regression loss had quietly turned into contrastive learning.
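A minimal reproduction of the shape blow-up (the tensors are random placeholders and bs = 4 is arbitrary):

import torch

bs = 4
labels = torch.randn(bs)        # [4]    -- the shape the dataloader actually returns
outputs = torch.randn(bs, 1)    # [4, 1] -- the shape the model returns

diff = labels - outputs         # broadcast: [4] vs [4, 1] -> [4, 4]
print(diff.shape)               # torch.Size([4, 4])

# Averaging the full matrix compares every label with every output,
# i.e. |labels[j] - outputs[i]| over all (i, j) pairs, not just i == j.
wrong_loss = torch.mean(torch.abs(labels - outputs))

# The intended per-sample MAE:
right_loss = torch.mean(torch.abs(labels.unsqueeze(-1) - outputs))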

How to fix it:

Make labels have shape [n, 1] rather than [n] before they are handed to the dataloader.
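One natural place to do this is the custom Dataset's __getitem__, so the default collate stacks the batch into [bs, 1] automatically. The class below is a hypothetical sketch, not the original code:

import torch
from torch.utils.data import Dataset, DataLoader

class MyRegressionDataset(Dataset):          # hypothetical name
    def __init__(self, inputs, labels):
        self.inputs = torch.as_tensor(inputs, dtype=torch.float32)   # [n, 4]
        self.labels = torch.as_tensor(labels, dtype=torch.float32)   # [n]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # unsqueeze the scalar label to shape [1]; the default collate
        # then stacks the batch into [bs, 1] instead of [bs].
        return self.inputs[idx], self.labels[idx].unsqueeze(-1)

loader = DataLoader(MyRegressionDataset(torch.randn(8, 4), torch.randn(8)), batch_size=4)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)    # torch.Size([4, 4]) torch.Size([4, 1])

# Alternatively, patch it right before computing the loss:
# labels = labels.unsqueeze(-1)    # [bs] -> [bs, 1]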

How to avoid this kind of problem in the first place:

Use F.l1_loss(): it checks whether the two tensors have the same size, and when they don't it emits a loud UserWarning about likely-incorrect results from broadcasting (its source is quoted below), which makes the bug much easier to spot:

So when hand-rolling a loss, check the shapes yourself to avoid silent broadcasting; a short demonstration follows the quoted source.

def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
    r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    Function that takes the mean element-wise absolute value difference.

    See :class:`~torch.nn.L1Loss` for details.
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce,
                reduction=reduction)
    if not (target.size() == input.size()):
        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
                      "This will likely lead to incorrect results due to broadcasting. "
                      "Please ensure they have the same size.".format(target.size(), input.size()),
                      stacklevel=2)
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        ret = torch.abs(input - target)
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    else:
        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
        ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
    return ret
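A quick check of that behaviour, plus the equivalent guard for a hand-rolled loss. This is a sketch assuming a recent PyTorch build, where the size mismatch produces a UserWarning rather than a hard error:

import torch
import torch.nn.functional as F

bs = 4
outputs = torch.randn(bs, 1)     # [4, 1]
labels = torch.randn(bs)         # [4]  -- mismatched on purpose

# Emits: "Using a target size (torch.Size([4])) that is different to the
# input size (torch.Size([4, 1])). This will likely lead to incorrect
# results due to broadcasting. Please ensure they have the same size."
F.l1_loss(outputs, labels)

# No warning once the shapes agree:
F.l1_loss(outputs, labels.unsqueeze(-1))

# The same idea for a hand-rolled loss: fail fast instead of broadcasting.
def my_mae(outputs, labels):
    assert outputs.shape == labels.shape, (
        f"shape mismatch: outputs {tuple(outputs.shape)} vs labels {tuple(labels.shape)}"
    )
    return torch.mean(torch.abs(outputs - labels))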