import torch
import torch.nn as nn
from losses_pytorch.ND_Crossentropy import CrossEntropy, TopkLoss
from scipy.ndimage import distance_transform_edt
import numpy as np
from skimage import segmentation as skimage_seg
def softamx_helper(x, dim=1):
    """Return a numerically stable softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before ``exp`` so that large
    logits do not overflow. Broadcasting replaces explicit ``repeat`` calls,
    so no channel-sized copies of the max/sum tensors are allocated.

    Args:
        x: input tensor with at least ``dim + 1`` dimensions.
        dim: axis to normalise over; defaults to 1, the channel axis.

    Returns:
        Tensor of the same shape as ``x`` whose entries sum to 1 along ``dim``.
    """
    x_max = x.max(dim=dim, keepdim=True)[0]
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(dim=dim, keepdim=True)
def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over every axis listed in ``axes``.

    Args:
        inp: tensor to reduce.
        axes: iterable of axis indices (ints, numpy ints, ...). Negative
            indices and duplicates are allowed; each axis is reduced once.
        keepdim: keep reduced axes as size-1 dimensions.

    Returns:
        The reduced tensor, or ``inp`` unchanged when ``axes`` is empty.
    """
    ndim = inp.dim()
    unique_axes = set()
    for ax in axes:
        ax = int(ax)
        # Map in-range negative axes to positive ones so the ordering below
        # is meaningful; out-of-range values are left for torch to reject.
        unique_axes.add(ax + ndim if -ndim <= ax < 0 else ax)
    # Without keepdim each reduction removes a dimension and shifts the
    # indices of all later axes, so reduce from the highest axis down.
    for ax in sorted(unique_axes, reverse=not keepdim):
        inp = inp.sum(ax, keepdim=keepdim)
    return inp
def tp_tn_fp_fn(net_out, target, axes=None, mask=None, square=False):
    """Compute soft true/false positive/negative counts per sample and class.

    Args:
        net_out: network output of shape (batch, class, *spatial). Presumably
            already normalised to probabilities (e.g. softmax) -- confirm
            with callers.
        target: label map of shape (batch, *spatial) or (batch, 1, *spatial),
            or a one-hot tensor with exactly the shape of ``net_out``.
        axes: axes to sum over; defaults to all spatial axes (2, 3, ...).
            NOTE(review): the final ``view(-1, num_class)`` assumes that only
            the batch and class axes remain unreduced.
        mask: optional per-voxel weights of shape (batch, 1, *spatial) or
            (batch, *spatial), broadcast over the class axis (e.g. 1 = keep,
            0 = ignore). Earlier versions accepted this argument but ignored it.
        square: if True, square the per-voxel terms before summing. Earlier
            versions accepted this argument but ignored it.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of tensors reshaped to (-1, num_class).
    """
    num_class = net_out.shape[1]
    if axes is None:
        axes = tuple(range(2, net_out.dim()))

    with torch.no_grad():
        if target.dim() != net_out.dim():
            # Label map without a channel axis: insert one so it can be scattered.
            target = target.view(target.shape[0], 1, *target.shape[1:])
        # Compare the *reshaped* target: zipping the raw shapes could declare a
        # label map "one-hot" merely because its leading dims happened to match.
        if target.shape == net_out.shape:
            one_hot = target
        else:
            # Allocate on net_out's device so GPU inputs work.
            one_hot = torch.zeros(net_out.shape, device=net_out.device, dtype=net_out.dtype)
            one_hot.scatter_(1, target.long(), 1)

    terms = [
        net_out * one_hot,              # tp
        (1 - net_out) * (1 - one_hot),  # tn
        net_out * (1 - one_hot),        # fp
        (1 - net_out) * one_hot,        # fn
    ]
    if mask is not None:
        if mask.dim() != net_out.dim():
            mask = mask.view(mask.shape[0], 1, *mask.shape[1:])
        terms = [t * mask for t in terms]
    if square:
        terms = [t ** 2 for t in terms]

    tp, tn, fp, fn = (sum_tensor(t, axes, keepdim=True).view(-1, num_class) for t in terms)
    return tp, tn, fp, fn
def boudary_weight(target, out_shape):
    """Compute the signed distance map (SDM) of every binary mask in ``target``.

    For each (batch, channel) mask the map is:

        sdm(p) = 0                     if p is on the inner boundary of the mask
        sdm(p) = -dist(p, background)  if p is inside the mask
        sdm(p) = +dist(p, mask)        if p is outside the mask

    Distances are Euclidean (``scipy.ndimage.distance_transform_edt``). The
    boundary comes from ``skimage.segmentation.find_boundaries(mode='inner')``.
    Channels whose mask is empty are left as all zeros.

    Args:
        target: numpy array indexed as ``target[b][c]``, e.g. a one-hot
            segmentation of shape (batch_size, class, x, y[, z]). Values are
            cast to uint8 first; non-zero results count as foreground.
        out_shape: shape of the returned array. ``out_shape[0]`` and
            ``out_shape[1]`` give the number of batches and channels to process.

    Returns:
        float64 numpy array of shape ``out_shape``.
    """
    target = target.astype(np.uint8)
    weight = np.zeros(out_shape)
    for b in range(out_shape[0]):  # batch
        for c in range(out_shape[1]):  # channel
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            posmask = target[b][c].astype(bool)
            if not posmask.any():
                continue
            negmask = ~posmask
            posdis = distance_transform_edt(posmask)  # inside: distance to background
            negdis = distance_transform_edt(negmask)  # outside: distance to the mask
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            sdm = negdis - posdis
            sdm[boundary == 1] = 0
            weight[b][c] = sdm
    return weight
class BDLoss(nn.Module):
    """Boundary loss: spatial sum of softmax(net_out) * SDM(target).

    The signed distance map (SDM) of the one-hot target is produced by
    ``boudary_weight``. It is negative inside each class region, positive
    outside, and zero on the inner boundary. The spatial sum is averaged
    over batch and class.
    """

    def __init__(self, do_bg=True):
        """
        Args:
            do_bg: include channel 0 in the loss. Defaults to True, which
                matches the earlier behaviour. Pass False to skip channel 0
                (presumably background).
        """
        super(BDLoss, self).__init__()
        self.do_bg = do_bg

    def forward(self, net_out, target):
        """Return the scalar boundary loss.

        Args:
            net_out: raw logits of shape (batch_size, class, *spatial), with one
                or more spatial axes.
            target: label map of shape (batch_size, *spatial) or
                (batch_size, 1, *spatial), or a one-hot tensor with the same
                shape as net_out.
        """
        net_out = torch.softmax(net_out, dim=1)
        with torch.no_grad():
            if target.dim() != net_out.dim():
                # Insert a channel axis so the label map can be scattered.
                target = target.view(target.shape[0], 1, *target.shape[1:])
            if target.shape == net_out.shape:
                one_hot = target
            else:
                # Allocate on net_out's device so GPU inputs work.
                one_hot = torch.zeros(net_out.shape, device=net_out.device)
                one_hot.scatter_(1, target.long(), 1)
            # The SDM is computed with scipy/skimage on host numpy arrays.
            target_sdf = boudary_weight(one_hot.detach().cpu().numpy(), net_out.shape)
        phi = torch.from_numpy(target_sdf).to(device=net_out.device, dtype=torch.float32)

        # getattr keeps instances pickled before `do_bg` existed working.
        first = 0 if getattr(self, 'do_bg', True) else 1
        pred = net_out[:, first:].float()
        phi = phi[:, first:]
        # Elementwise product summed over every spatial axis -> (batch, class).
        bd_loss = (pred * phi).flatten(start_dim=2).sum(dim=-1)
        return bd_loss.mean()
class HDLoss(nn.Module):
    """Distance-weighted squared-error loss (Hausdorff-distance style).

    Each voxel's squared error between softmax(net_out) and the one-hot target
    is weighted by ``d_pred ** 2 + d_target ** 2``. Here ``d_pred`` is the
    Euclidean distance transform of the thresholded prediction (prob > 0.5)
    and ``d_target`` is that of the target mask, both computed per
    (batch, class). The weighted errors are averaged over all elements.
    """

    def __init__(self):
        super(HDLoss, self).__init__()

    def forward(self, net_out, target):
        """Return the scalar loss.

        Args:
            net_out: raw logits of shape (batch_size, class, *spatial).
            target: label map of shape (batch_size, *spatial) or
                (batch_size, 1, *spatial), or a one-hot tensor with the same
                shape as net_out.
        """
        net_out = torch.softmax(net_out, dim=1)
        with torch.no_grad():
            if target.dim() != net_out.dim():
                # Insert a channel axis so the label map can be scattered.
                target = target.view(target.shape[0], 1, *target.shape[1:])
            if target.shape == net_out.shape:
                one_hot = target
            else:
                # Allocate on net_out's device so GPU inputs work.
                one_hot = torch.zeros(net_out.shape, device=net_out.device, dtype=net_out.dtype)
                one_hot.scatter_(1, target.long(), 1)
            # The distance transforms run in scipy. Detach first, because
            # net_out requires grad during training and .numpy() would raise.
            # Then cast to float32 and move to host memory.
            distance_netout = self.distance_weight_netout(net_out.detach().float().cpu().numpy())
            distance_target = self.distance_weight_target(one_hot.detach().float().cpu().numpy())
        # float32 keeps large squared distances from overflowing in half precision.
        distance = torch.from_numpy(distance_netout ** 2 + distance_target ** 2).to(
            device=net_out.device, dtype=torch.float32)
        pre_error = (net_out - one_hot) ** 2
        return (pre_error * distance).mean()

    def distance_weight_netout(self, net_out):
        """Return the distance transform of each prediction mask ``net_out > 0.5``.

        Args:
            net_out: numpy array of probabilities, shape (batch, class, *spatial).

        Returns:
            float64 array of the same shape. Masks with no value above 0.5
            yield zeros.
        """
        return self._channelwise_edt(net_out > 0.5)

    def distance_weight_target(self, target):
        """Return the distance transform of each target mask (non-zero = foreground).

        Args:
            target: numpy array of shape (batch, class, *spatial).

        Returns:
            float64 array of the same shape. Empty masks yield zeros.
        """
        return self._channelwise_edt(target.astype(np.bool_))

    @staticmethod
    def _channelwise_edt(masks):
        """Return the Euclidean distance transform of every ``masks[b, c]``; empty masks give zeros."""
        weight = np.zeros(masks.shape)
        for b in range(masks.shape[0]):  # batch
            for c in range(masks.shape[1]):  # channel
                if masks[b, c].any():
                    weight[b, c] = distance_transform_edt(masks[b, c])
        return weight
if __name__ == '__main__':
    # Smoke test: a batch of one 2-class 4x4 logit map, scored against a
    # label map whose first two columns are class 1 and last two class 0.
    class0_row = [0.2, 0.2, 0.7, 0.7]
    class1_row = [0.8, 0.8, 0.3, 0.3]
    logits = torch.tensor([[[class0_row] * 4, [class1_row] * 4]])
    labels = torch.tensor([[[1, 1, 0, 0]] * 4])

    criterion = HDLoss()
    print(criterion(logits, labels))
# boundary_loss
# (Scraped page footer, translated from Chinese) Latest recommended article published 2024-04-30 15:11:08