# boundary_loss

import torch
import torch.nn as nn
from losses_pytorch.ND_Crossentropy import CrossEntropy, TopkLoss
from scipy.ndimage import distance_transform_edt
import numpy as np
from skimage import segmentation as skimage_seg

def softamx_helper(x):
    """Return a numerically stable softmax of ``x`` taken over dimension 1.

    The per-position maximum along dim 1 is subtracted before exponentiating so
    large inputs do not overflow. The reduced tensors are kept with a size-1
    dim 1 (``keepdim=True``) and broadcast back over that axis.
    """
    peak = x.max(dim=1, keepdim=True).values
    exps = (x - peak).exp()
    denom = exps.sum(dim=1, keepdim=True)
    return exps / denom

def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over every axis listed in ``axes``.

    Args:
        inp: tensor to reduce.
        axes: iterable of axis indices (ints or int-like values such as NumPy
            integers; negative indices follow PyTorch's usual convention).
            Repeated entries are ignored. An empty iterable returns ``inp``
            unchanged.
        keepdim: if True, the reduced axes are kept with size 1.

    Returns:
        The reduced tensor.
    """
    # Reduce all axes in a single call. Reducing one axis at a time with
    # keepdim=False removes a dimension per step and shifts the indices of the
    # remaining axes, e.g. axes=(2, 3) on a 4-D tensor used to raise IndexError.
    dims = tuple(dict.fromkeys(int(ax) for ax in axes))
    if not dims:
        # torch.sum(dim=()) may mean "reduce everything"; keep the no-op meaning.
        return inp
    return inp.sum(dim=dims, keepdim=keepdim)

def tp_tn_fp_fn(net_out, target, axes=None, mask=None, square=False):
    """Compute soft true/false positive/negative totals per class.

    Args:
        net_out: prediction tensor of shape (batch, class, *spatial). Its values
            are used directly (typically softmax probabilities).
        target: label tensor. Either (batch, *spatial) or (batch, 1, *spatial)
            holding integer class indices, or a tensor with exactly the shape of
            ``net_out`` that is taken to be one-hot already.
        axes: axes to sum over. Defaults to every spatial axis (2 .. ndim-1).
        mask: optional tensor broadcastable against ``net_out`` (for example
            (batch, 1, *spatial)). Each per-element term is multiplied by it
            before summing, so zeros exclude those elements.
        square: if True, each per-element term is squared (after masking)
            before summing.

    Returns:
        Tuple ``(tp, tn, fp, fn)``. Each is summed over ``axes`` with
        ``keepdim=True`` and then viewed as ``(-1, num_class)``, i.e.
        (batch, num_class) for the default ``axes``.
    """
    num_class = net_out.shape[1]
    if axes is None:
        axes = tuple(range(2, net_out.dim()))

    shp_x = net_out.shape
    shp_y = target.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # Label map without a channel axis: insert one at dim 1.
            target = target.view(shp_y[0], 1, *shp_y[1:])
        # Compare against the *reshaped* target. Comparing the raw shape lets,
        # e.g., a (b, 4, 4) label map "match" a (b, 4, 4, 4) prediction.
        if all(i == j for i, j in zip(shp_x, target.shape)):
            one_hot = target
        else:
            # Allocate on net_out's device/dtype so the products below never
            # mix CPU and GPU tensors.
            one_hot = torch.zeros(shp_x, device=net_out.device, dtype=net_out.dtype)
            one_hot.scatter_(1, target.long(), 1)

    tp = net_out * one_hot
    tn = (1 - net_out) * (1 - one_hot)
    fp = net_out * (1 - one_hot)
    fn = (1 - net_out) * one_hot

    if mask is not None:
        tp, tn, fp, fn = (t * mask for t in (tp, tn, fp, fn))
    if square:
        tp, tn, fp, fn = (t ** 2 for t in (tp, tn, fp, fn))

    tp, tn, fp, fn = (
        sum_tensor(t, axes, keepdim=True).view(-1, num_class)
        for t in (tp, tn, fp, fn)
    )
    return tp, tn, fp, fn

def boudary_weight(target, out_shape):
    """Compute the signed distance map (SDM) of every channel of a binary mask.

    Args:
        target: array-like of shape (batch, class, *spatial). It is cast to
            uint8 and then to bool, so any entry that is nonzero after the uint8
            cast counts as foreground of that class channel.
        out_shape: shape of the returned array. Its first two entries are the
            batch and channel counts that are iterated over.

    Returns:
        float64 ndarray of shape ``out_shape``. For every channel that has any
        foreground:
            sdm(p) = 0                      if p is on the inner boundary of the mask
            sdm(p) = -dist(p, background)   if p is inside the mask
            sdm(p) = +dist(p, foreground)   if p is outside the mask
        Channels with no foreground are left as all zeros.
    """
    target = np.asarray(target).astype(np.uint8)
    weight = np.zeros(out_shape)

    for b in range(out_shape[0]):  # batch
        for c in range(out_shape[1]):  # channel
            # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
            # same dtype.
            posmask = target[b][c].astype(bool)
            if not posmask.any():
                continue
            negmask = ~posmask
            # Inside the mask: distance to the nearest background element.
            posdis = distance_transform_edt(posmask)
            # Outside the mask: distance to the nearest foreground element.
            negdis = distance_transform_edt(negmask)
            boundary = skimage_seg.find_boundaries(posmask, mode='inner')
            sdm = negdis - posdis
            sdm[boundary] = 0
            weight[b][c] = sdm

    return weight

class BDLoss(nn.Module):
    """Boundary loss: softmax probabilities weighted by the target's signed distance map.

    For each (sample, class), the product ``softmax(net_out) * SDM(target)`` is
    summed over all spatial positions, and the result is averaged over batch and
    class. The SDM is computed on the host with ``boudary_weight`` and is a
    constant: no gradient flows through it.
    """

    def __init__(self):
        super(BDLoss, self).__init__()
        # NOTE: every channel, including channel 0, contributes to the loss.

    def forward(self, net_out, target):
        """Compute the boundary loss.

        Args:
            net_out: raw logits of shape (batch_size, class, *spatial). Softmax
                over dim 1 is applied here.
            target: integer class-index map of shape (batch_size, *spatial) or
                (batch_size, 1, *spatial), or a one-hot tensor with the same
                shape as ``net_out``.

        Returns:
            Scalar float32 tensor.
        """
        net_out = torch.softmax(net_out, dim=1)
        shp_x = net_out.shape
        shp_y = target.shape

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                target = target.view((shp_y[0], 1, *shp_y[1:]))
            # Compare against the reshaped target, not the raw input shape.
            if all(i == j for i, j in zip(shp_x, target.shape)):
                one_hot = target
            else:
                one_hot = torch.zeros(shp_x, device=target.device)
                one_hot.scatter_(1, target.long(), 1)

        # SciPy / skimage need host NumPy arrays; .cpu() makes this work for
        # CUDA inputs as well.
        target_sdf = boudary_weight(one_hot.detach().cpu().numpy(), shp_x)

        pred = net_out.type(torch.float32)
        phi = torch.from_numpy(target_sdf).to(device=pred.device, dtype=torch.float32)

        # Element-wise product summed over every spatial axis (2-D or 3-D),
        # leaving a (batch, class) tensor that is then averaged.
        spatial_axes = tuple(range(2, pred.dim()))
        bd_loss = (pred * phi).sum(dim=spatial_axes)
        return bd_loss.mean()
        


class HDLoss(nn.Module):
    """Distance-weighted squared-error loss.

    With ``p = softmax(net_out)``, one-hot target ``g``, and per-channel
    Euclidean distance transforms ``d_p`` (of ``p > 0.5``) and ``d_g`` (of
    ``g``), the loss is ``mean((p - g) ** 2 * (d_p ** 2 + d_g ** 2))``. The name
    suggests a Hausdorff-distance surrogate. The distance maps are computed with
    SciPy on the host and are treated as constants: no gradient flows through
    them.
    """

    def __init__(self):
        super(HDLoss, self).__init__()

    def forward(self, net_out, target):
        """Compute the loss.

        Args:
            net_out: raw logits of shape (batch, class, *spatial). Softmax over
                dim 1 is applied here.
            target: integer class-index map of shape (batch, *spatial) or
                (batch, 1, *spatial), or a one-hot tensor with the same shape as
                ``net_out``.

        Returns:
            Scalar tensor: the mean over all elements of the weighted squared
            error.
        """
        net_out = torch.softmax(net_out, dim=1)
        shp_x = net_out.shape
        shp_y = target.shape

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                target = target.view(shp_y[0], 1, *shp_y[1:])
            # Compare against the reshaped target. The raw shape can spuriously
            # "match", e.g. a (b, 4, 4) label map vs a (b, 4, 4, 4) prediction.
            if all(i == j for i, j in zip(shp_x, target.shape)):
                one_hot = target
            else:
                one_hot = torch.zeros(shp_x, device=net_out.device, dtype=net_out.dtype)
                one_hot.scatter_(1, target.long(), 1)

            # .detach() is required because net_out usually requires grad
            # (.numpy() refuses such tensors); .cpu() handles CUDA inputs.
            distance_netout = self.distance_weight_netout(net_out.detach().cpu().numpy())
            distance_target = self.distance_weight_target(one_hot.detach().cpu().numpy())
            distance = torch.from_numpy(distance_netout ** 2 + distance_target ** 2).to(
                device=net_out.device, dtype=net_out.dtype
            )

        pre_error = (net_out - one_hot) ** 2
        # Element-wise weighting works for any number of spatial dims.
        return (pre_error * distance).mean()

    @staticmethod
    def _per_channel_edt(masks):
        """Euclidean distance transform of each (batch, channel) boolean mask; empty masks stay 0."""
        weight = np.zeros(masks.shape)
        for i in range(masks.shape[0]):  # batch
            for j in range(masks.shape[1]):  # channel
                if masks[i, j].any():
                    # Foreground elements get their distance to the nearest
                    # background element; background elements get 0.
                    weight[i, j] = distance_transform_edt(masks[i, j])
        return weight

    def distance_weight_netout(self, net_out, threshold=0.5):
        """Per-channel EDT of the prediction thresholded at ``threshold``.

        Args:
            net_out: ndarray of shape (batch, class, *spatial), typically
                probabilities.
            threshold: elements strictly greater than this are foreground.

        Returns:
            float64 ndarray with the same shape as ``net_out``.
        """
        return self._per_channel_edt(np.asarray(net_out) > threshold)

    def distance_weight_target(self, target):
        """Per-channel EDT of the target, whose nonzero entries are foreground.

        Args:
            target: ndarray of shape (batch, class, *spatial).

        Returns:
            float64 ndarray with the same shape as ``target``.
        """
        return self._per_channel_edt(np.asarray(target).astype(np.bool_))


if __name__ == '__main__':
    # Smoke test on a 1x2x4x4 logit tensor. Channel 0 scores higher on the
    # right half and channel 1 on the left half; the label map marks the left
    # half as class 1 and the right half as class 0.
    right_favoured = [0.2, 0.2, 0.7, 0.7]
    left_favoured = [0.8, 0.8, 0.3, 0.3]
    img = torch.tensor([[[right_favoured] * 4, [left_favoured] * 4]])
    target = torch.tensor([[[1, 1, 0, 0]] * 4])
    print(HDLoss()(img, target))
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
可以将这部分代码转化为PyTorch,转化后的代码如下 (this code can be converted to PyTorch as follows):

import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_entropy_loss(y_true, y_pred):
    # 计算交叉熵损失 (compute the cross-entropy loss)
    cross_entropy = nn.CrossEntropyLoss()(y_pred, y_true)
    return cross_entropy

def boundary_loss(y_true, y_pred):
    # 计算边界损失 (compute the boundary loss)
    boundary_filter = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32)
    boundary_filter = boundary_filter.view(1, 1, 3, 3)
    y_true_boundary = F.conv2d(y_true, boundary_filter, padding=1)
    y_pred_boundary = F.conv2d(y_pred, boundary_filter, padding=1)
    boundary_loss = F.mse_loss(y_true_boundary, y_pred_boundary)
    return boundary_loss

def total_loss(y_true, y_pred):
    # 总损失函数 = 交叉熵损失 + 边界损失 (total loss = cross-entropy loss + boundary loss)
    return cross_entropy_loss(y_true, y_pred) + 0.5 * boundary_loss(y_true, y_pred)

# 构建模型 (build the model)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32*8*8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32*8*8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Model()
# 编译模型 (compile the model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = total_loss
metrics = ['accuracy']

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

后天...

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值