blob_loss

import torch
from torch.nn import BCEWithLogitsLoss

VERBOSE = False

def vprint(*args):
    # debug helper: prints only when the module-level VERBOSE flag is enabled
    if VERBOSE:
        print(*args)

def compute_compound_loss(criterion_dict: dict, raw_network_outputs: torch.Tensor, label: torch.Tensor,
                          blob_loss_mode=True, masked=True):
    """
    这通过循环标准dict计算复合损失!
    """
    # vprint("outputs:", outputs)
    losses = []
    for entry in criterion_dict.values():
        vprint("loss name:", entry["name"])
        criterion = entry["loss"]
        weight = entry["weight"]
        sigmoid = entry["sigmoid"]

        if not blob_loss_mode:
            vprint("computing main loss!")
            # collapse the channel dimension with a max to get one global map
            raw_network_output, _ = torch.max(raw_network_outputs, dim=1)
            if sigmoid:
                sigmoid_network_outputs = torch.sigmoid(raw_network_output)
                individual_loss = criterion(sigmoid_network_outputs, label)
            else:
                individual_loss = criterion(raw_network_output, label.float())

        else:
            vprint("computing blob loss!")
            if masked:  # this is the default blob loss
                if sigmoid:
                    sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
                    individual_loss = compute_blob_loss_multi(criterion, sigmoid_network_outputs, label)
                else:
                    individual_loss = compute_blob_loss_multi(criterion, raw_network_outputs, label)
            else:  # without masking, for the ablation study
                if sigmoid:
                    sigmoid_network_outputs = torch.sigmoid(raw_network_outputs)
                    individual_loss = compute_no_masking_multi(criterion, sigmoid_network_outputs, label)
                else:
                    individual_loss = compute_no_masking_multi(criterion, raw_network_outputs, label)

        weighted_loss = individual_loss * weight
        losses.append(weighted_loss)

    loss = sum(losses)
    return loss
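

# A minimal usage sketch for the main-loss path (toy tensors, assumed for
# illustration; only a BCE criterion): the channel dimension is collapsed
# with a max before comparing against the binary label. Call manually to try.
def _demo_compound_loss():
    criteria = {
        "bce": {"name": "bce", "loss": BCEWithLogitsLoss(reduction="mean"),
                "weight": 1.0, "sigmoid": False},
    }
    outputs = torch.randn(2, 4, 5, 5)            # (batch, channels, x, y) logits
    binary = (torch.rand(2, 5, 5) > 0.5).int()   # (batch, x, y) binary labels
    print(compute_compound_loss(criteria, outputs, binary, blob_loss_mode=False))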


def compute_blob_loss_multi(criterion, network_outputs: torch.Tensor, multi_label: torch.Tensor):
    """
    1、循环我们批次中的元素
    2、循环通过每个元素的blob计算损耗,并除以blob得到元素损耗
    2.1我们需要与BCE一起考虑乙状结肠和非乙状结肠
    3、除以批次长度,得到正确的后支柱批次损失
    """
    batch_length = multi_label.shape[0]
    element_blob_loss = []
    # loop over elements
    for element in range(batch_length):
        # slice out one batch element while keeping the batch dimension;
        # element < batch_length always holds inside this loop
        element_label = multi_label[element:element + 1, ...]
        element_output = network_outputs[element:element + 1, ...]

        # loop through labels
        unique_labels = torch.unique(element_label)
        label_loss = []
        for ula in unique_labels:
            if ula == 0:
                vprint("ula is 0, background: nothing to do")
            else:
                # build a mask that keeps the background and the current blob
                # and hides all competing blobs
                label_mask = element_label > 0
                label_mask = ~label_mask
                label_mask[element_label == ula] = 1

                # binary target for the current blob only
                the_label = element_label == ula
                the_label_int = the_label.int()

                masked_output = element_output * label_mask
                try:
                    blob_loss = criterion(masked_output, the_label_int)
                except (RuntimeError, TypeError):
                    # some criteria (e.g. BCEWithLogitsLoss) need float targets
                    blob_loss = criterion(masked_output, the_label.float())
                label_loss.append(blob_loss)

        # average over blobs to get the element loss (once per element,
        # not once per blob)
        if len(label_loss) != 0:
            mean_label_loss = sum(label_loss) / len(label_loss)
            element_blob_loss.append(mean_label_loss)

    # average over the batch elements
    mean_element_blob_loss = 0
    if len(element_blob_loss) != 0:
        mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)
    return mean_element_blob_loss
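

# To see what the masking step does, a tiny sketch (toy tensors, assumed for
# illustration): for the current blob label, the mask keeps the background and
# that blob and zeroes out every competing blob, so the criterion is never
# penalized inside other blobs' regions.
def _demo_blob_mask():
    element_label = torch.tensor([[0, 1, 1, 0, 2, 2]])
    label_mask = element_label > 0         # True wherever any blob lives
    label_mask = ~label_mask               # keep the background ...
    label_mask[element_label == 1] = 1     # ... and the current blob (ula == 1)
    print(label_mask.int())                # tensor([[1, 1, 1, 1, 0, 0]])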


def compute_no_masking_multi(criterion, network_outputs: torch.Tensor, multi_label: torch.Tensor):
    """
    1、循环我们批次中的元素
    2、循环通过每个元素的blob计算损耗,并除以blob得到元素损耗
    2.1我们需要与BCE一起考虑乙状结肠和非乙状结肠
    3、除以批次长度,得到正确的后支柱批次损失
    """
    batch_length = multi_label.shape[0]
    element_blob_loss = []
    # loop over elements
    for element in range(batch_length):
        # slice out one batch element while keeping the batch dimension;
        # element < batch_length always holds inside this loop
        element_label = multi_label[element:element + 1, ...]
        element_output = network_outputs[element:element + 1, ...]

        # loop through labels
        unique_labels = torch.unique(element_label)
        label_loss = []
        for ula in unique_labels:
            if ula == 0:
                vprint("ula is 0, background: nothing to do")
            else:
                # binary target for the current blob; the output is NOT masked
                # here (ablation variant)
                the_label = element_label == ula
                the_label_int = the_label.int()
                try:
                    blob_loss = criterion(element_output, the_label_int)
                except (RuntimeError, TypeError):
                    # some criteria (e.g. BCEWithLogitsLoss) need float targets
                    blob_loss = criterion(element_output, the_label.float())
                label_loss.append(blob_loss)

        # average over blobs to get the element loss (once per element,
        # not once per blob)
        if len(label_loss) != 0:
            mean_label_loss = sum(label_loss) / len(label_loss)
            element_blob_loss.append(mean_label_loss)

    # average over the batch elements
    mean_element_blob_loss = 0
    if len(element_blob_loss) != 0:
        mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)

    return mean_element_blob_loss


def compute_loss(blob_loss_dict: dict, criterion_dict: dict, blob_criterion_dict: dict,
                    raw_network_outputs: torch.Tensor, binary_label: torch.Tensor, multi_label: torch.Tensor):
    """
    此函数用于计算总损失。
    它有一个全局主损失和blob损失项,对于每个连接的组件分别计算。
    binary_label是全局零件的二进制标签。
    multi_label为每个连接的组件提供了单独的整数标签。
    Example inputs should look like:
    blob_loss_dict = {
        "main_weight": 1,
        "blob_weight": 0,
    }
    criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }
    blob_criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }
    """
    main_weight = blob_loss_dict["main_weight"]
    blob_weight = blob_loss_dict["blob_weight"]

    main_loss = 0
    blob_loss = 0

    # main loss
    if main_weight > 0:
        main_loss = compute_compound_loss(criterion_dict=criterion_dict,
                                          raw_network_outputs=raw_network_outputs,
                                          label=binary_label, blob_loss_mode=False)
    # blob loss
    if blob_weight > 0:
        blob_loss = compute_compound_loss(criterion_dict=blob_criterion_dict,
                                          raw_network_outputs=raw_network_outputs,
                                          label=multi_label, blob_loss_mode=True)

    # final loss
    if blob_weight == 0 and main_weight > 0:
        loss = main_loss
    elif main_weight == 0 and blob_weight > 0:
        loss = blob_loss
    elif main_weight > 0 and blob_weight > 0:
        loss = main_loss * main_weight + blob_loss * blob_weight
    else:
        vprint("defaulting to equal weighted blob loss")
        loss = main_loss + blob_loss

    vprint("blob loss:", blob_loss)
    vprint("main loss:", main_loss)
    vprint("effective loss:", loss)

    return loss, main_loss, blob_loss

def one_hot(net_out, target):
    """Expand an integer label map to a one-hot encoding with net_out's shape."""
    shp_x = net_out.shape
    shp_y = target.shape

    if len(shp_x) != len(shp_y):
        # add a channel dimension: (b, x, y(, z)) -> (b, 1, x, y(, z))
        target = target.view(target.shape[0], 1, *target.shape[1:])
    if all([i == j for i, j in zip(net_out.shape, target.shape)]):
        # target is already one-hot encoded
        encoded = target
    else:
        target = target.long()
        # allocate on the same device as net_out, then scatter 1s along
        # the channel dimension
        encoded = torch.zeros(net_out.shape, device=net_out.device)
        encoded.scatter_(1, target, 1)

    return encoded
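

# Quick shape check for one_hot (toy tensors, assumed for illustration):
# a (1, 2, 2) label map with four classes, paired with a (1, 4, 2, 2) output,
# becomes a (1, 4, 2, 2) one-hot tensor.
def _demo_one_hot():
    logits = torch.randn(1, 4, 2, 2)
    labels = torch.tensor([[[0, 1], [2, 3]]])
    encoded = one_hot(logits, labels)
    print(encoded.shape)         # torch.Size([1, 4, 2, 2])
    print(encoded[0, :, 0, 0])   # tensor([1., 0., 0., 0.]) -> class 0 at (0, 0)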

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/dice_loss.py
    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = torch.sum(tp)
    tn = torch.sum(tn)
    fp = torch.sum(fp)
    fn = torch.sum(fn)
    print("tp: {}\ntn: {}\nfp: {}\nfn: {}".format(tp, tn, fp, fn))

    return tp, tn, fp, fn
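

# Toy check for get_tp_fp_fn (made-up hard 0/1 "predictions"): each of the
# four pixels lands in exactly one confusion-matrix cell, so tp, tn, fp and
# fn all come out as 1.0.
def _demo_tp_fp_fn():
    pred = torch.tensor([[[[1.0, 1.0], [0.0, 0.0]]]])  # (b, c, x, y)
    gt = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
    print(get_tp_fp_fn(pred, gt))  # (tensor(1.), tensor(1.), tensor(1.), tensor(1.))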

def SoftDiceLoss(x, y, loss_mask=None):
    # the mask must be passed by keyword: the third positional argument of
    # get_tp_fp_fn is `axes`, not `mask`
    tp, tn, fp, fn = get_tp_fp_fn(x, y, mask=loss_mask)
    dc = (2 * tp + 1e-7) / (2 * tp + fp + fn + 1e-7)
    dc = dc.mean()
    return 1 - dc
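

# Sanity check for SoftDiceLoss (toy tensors, assumed for illustration):
# a perfect soft prediction yields a loss near 0, a completely wrong one
# a loss near 1.
def _demo_soft_dice():
    y = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])  # (b, c, x, y)
    print(SoftDiceLoss(y.clone(), y))   # ~0.0
    print(SoftDiceLoss(1.0 - y, y))     # ~1.0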

if __name__ == '__main__':
    net_out = torch.tensor(
        [[[[0.8, 0.8, 0.8, 0.1, 0.1],
           [0.8, 0.8, 0.8, 0.1, 0.1],
           [0.8, 0.8, 0.8, 0.1, 0.1],
           [0.1, 0.1, 0.1, 0.2, 0.2],
           [0.1, 0.1, 0.1, 0.2, 0.2]],

          [[0.05, 0.05, 0.05, 0.8, 0.8],
           [0.05, 0.05, 0.05, 0.8, 0.8],
           [0.05, 0.05, 0.05, 0.8, 0.8],
           [0.04, 0.04, 0.04, 0.05, 0.05],
           [0.04, 0.04, 0.04, 0.05, 0.05]],

          [[0.04, 0.04, 0.04, 0.03, 0.03],
           [0.04, 0.04, 0.04, 0.03, 0.03],
           [0.04, 0.04, 0.04, 0.03, 0.03],
           [0.7, 0.7, 0.7, 0.15, 0.15],
           [0.7, 0.7, 0.7, 0.15, 0.15]],

          [[0.09, 0.09, 0.09, 0.07, 0.07],
           [0.09, 0.09, 0.09, 0.07, 0.07],
           [0.09, 0.09, 0.09, 0.07, 0.07],
           [0.16, 0.16, 0.16, 0.6, 0.6],
           [0.16, 0.16, 0.16, 0.6, 0.6]]]]
    )
    target = torch.tensor(
        [[[1, 1, 1, 2, 2],
          [1, 1, 1, 2, 2],
          [1, 1, 1, 2, 2],
          [3, 3, 3, 0, 0],
          [3, 3, 3, 0, 0]]]
    )
    binary_label = torch.tensor(
        [[[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]]
    )
    multi_label = one_hot(net_out, target)
    blob_loss_dict = {
        "main_weight": 1,
        "blob_weight": 1,
    }

    criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": SoftDiceLoss,
            "weight": 1.0,
            "sigmoid": False,
        },
    }

    blob_criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": SoftDiceLoss,
            "weight": 1.0,
            "sigmoid": False,
        },
    }

    out = compute_loss(blob_loss_dict=blob_loss_dict, criterion_dict=criterion_dict, blob_criterion_dict=blob_criterion_dict,
                       raw_network_outputs=net_out, binary_label=binary_label, multi_label=multi_label)

    print(out)