import torch
from torch.nn import BCEWithLogitsLoss
def vprint(*args):
    """Debug print: forward ``*args`` to ``print`` only when verbosity is on.

    Verbosity is a hard-coded local switch that is currently off, so every
    call is a no-op unless the flag below is edited.
    """
    enabled = False
    if not enabled:
        return None
    print(*args)
def compute_compound_loss(criterion_dict: dict, raw_network_outputs: torch.Tensor, label: torch.Tensor,
                          blob_loss_mode=True, masked=True):
    """Compute a compound loss by looping over the criterion dict.

    Each value of ``criterion_dict`` is a dict with the keys ``"name"`` (used
    for debug output), ``"loss"`` (callable ``loss(output, target)``),
    ``"weight"`` (multiplier for that loss) and ``"sigmoid"`` (whether
    ``torch.sigmoid`` is applied to the outputs before calling the loss).

    Args:
        criterion_dict: mapping of entry name -> criterion entry (see above).
        raw_network_outputs: raw network outputs. In main-loss mode dim 1 is
            reduced with ``torch.max`` before the loss is applied.
        label: target for the main loss, or the multi (per-blob) label in
            blob-loss mode.
        blob_loss_mode: falsy -> main (global) loss; truthy -> blob loss.
        masked: only used in blob-loss mode. Truthy -> masked blob loss (the
            default); falsy -> unmasked variant for the ablation study.

    Returns:
        ``sum(loss_i * weight_i)`` over all entries (``0`` for an empty dict).
    """
    losses = []
    for entry in criterion_dict.values():
        vprint("loss name:", entry["name"])
        criterion = entry["loss"]
        weight = entry["weight"]
        # Truthiness instead of ``is True``/``is False``: values such as 1, 0 or
        # numpy/torch bools used to match no branch, leaving individual_loss
        # unbound (first entry) or silently reused from the previous entry.
        use_sigmoid = bool(entry["sigmoid"])

        if not blob_loss_mode:
            vprint("computing main loss!")
            # Collapse the channel dimension (dim 1) by taking its maximum.
            raw_network_output, _ = torch.max(raw_network_outputs, dim=1)
            if use_sigmoid:
                # NOTE(review): unlike the raw branch, the label is passed
                # without .float() here; confirm the chosen criteria accept it.
                individual_loss = criterion(torch.sigmoid(raw_network_output), label)
            else:
                individual_loss = criterion(raw_network_output, label.float())
        else:
            vprint("computing blob loss!")
            outputs = torch.sigmoid(raw_network_outputs) if use_sigmoid else raw_network_outputs
            if masked:  # this is the default blob loss
                individual_loss = compute_blob_loss_multi(criterion, outputs, label)
            else:  # without masking for ablation study
                individual_loss = compute_no_masking_multi(criterion, outputs, label)

        losses.append(individual_loss * weight)
    return sum(losses)
def compute_blob_loss_multi(criterion, network_outputs: torch.Tensor, multi_label: torch.Tensor):
    """Compute the masked blob loss, averaged over blobs and then over batch elements.

    Each distinct non-zero value of ``multi_label`` within a batch element is
    treated as one blob. For each blob, the element's output is multiplied by
    a mask that keeps background (label 0) and the blob itself and zeroes out
    every other blob. ``criterion`` is then evaluated against that blob's
    binary target.

    1. Loop over the elements of the batch.
    2. For each element, compute one loss per blob and average them to get the
       element loss.
    3. Average the element losses over the elements that contain at least one
       blob, giving the batch loss used for backprop.

    Args:
        criterion: callable ``criterion(output, target)`` returning a loss.
            Int targets are tried first and float targets on failure.
        network_outputs: network outputs (raw or already sigmoid-activated),
            sliced along dim 0 in lockstep with ``multi_label``.
        multi_label: label tensor whose distinct non-zero values identify blobs.

    Returns:
        The mean blob loss, or the int ``0`` when no element contains a blob.
    """
    element_blob_loss = []
    for element in range(multi_label.shape[0]):
        # Slicing element:element + 1 keeps the batch dimension.
        element_label = multi_label[element:element + 1, ...]
        element_output = network_outputs[element:element + 1, ...]

        label_loss = []
        for ula in torch.unique(element_label):
            if ula == 0:
                vprint("ula is 0 we do nothing")
                continue
            # Keep background and the current blob; mask out all other blobs.
            label_mask = ~(element_label > 0)
            label_mask[element_label == ula] = 1
            the_label = element_label == ula
            the_label_int = the_label.int()
            masked_output = element_output * label_mask
            try:
                blob_loss = criterion(masked_output, the_label_int)
            except Exception:
                # Some losses presumably require float targets; retry with float.
                blob_loss = criterion(masked_output, the_label.float())
            label_loss.append(blob_loss)

        if label_loss:
            element_blob_loss.append(sum(label_loss) / len(label_loss))

    if not element_blob_loss:
        return 0
    return sum(element_blob_loss) / len(element_blob_loss)
def compute_no_masking_multi(criterion, network_outputs: torch.Tensor, multi_label: torch.Tensor):
    """Compute the blob loss without masking (ablation variant).

    This is identical to the masked blob loss except that each blob's loss is
    computed on the full, unmasked element output. Other blobs are therefore
    not hidden from ``criterion``.

    1. Loop over the elements of the batch.
    2. For each element, compute one loss per blob (distinct non-zero label
       value) and average them to get the element loss.
    3. Average the element losses over the elements that contain at least one
       blob, giving the batch loss used for backprop.

    Args:
        criterion: callable ``criterion(output, target)`` returning a loss.
            Int targets are tried first and float targets on failure.
        network_outputs: network outputs (raw or already sigmoid-activated),
            sliced along dim 0 in lockstep with ``multi_label``.
        multi_label: label tensor whose distinct non-zero values identify blobs.

    Returns:
        The mean blob loss, or the int ``0`` when no element contains a blob.
    """
    element_blob_loss = []
    for element in range(multi_label.shape[0]):
        # Slicing element:element + 1 keeps the batch dimension.
        element_label = multi_label[element:element + 1, ...]
        element_output = network_outputs[element:element + 1, ...]

        label_loss = []
        for ula in torch.unique(element_label):
            if ula == 0:
                vprint("ula is 0 we do nothing")
                continue
            the_label = element_label == ula
            the_label_int = the_label.int()
            try:
                blob_loss = criterion(element_output, the_label_int)
            except Exception:
                # Some losses presumably require float targets; retry with float.
                blob_loss = criterion(element_output, the_label.float())
            label_loss.append(blob_loss)

        if label_loss:
            element_blob_loss.append(sum(label_loss) / len(label_loss))

    if not element_blob_loss:
        return 0
    return sum(element_blob_loss) / len(element_blob_loss)
def compute_loss(blob_loss_dict: dict, criterion_dict: dict, blob_criterion_dict: dict,
                 raw_network_outputs: torch.Tensor, binary_label: torch.Tensor, multi_label: torch.Tensor):
    """
    Compute the total loss.

    The total loss combines a global main loss and a blob loss term. The blob
    loss is computed separately for each connected component.
    binary_label is the binary label for the global (main) term.
    multi_label provides a separate integer label for each connected component.

    Weighting rules, driven by blob_loss_dict["main_weight"] and
    blob_loss_dict["blob_weight"]:
      * both > 0: loss = main_loss * main_weight + blob_loss * blob_weight
      * only one > 0: loss is that term alone and UNweighted. The weight acts
        only as an on/off switch; the other term is reported as 0.
      * neither > 0: both terms are computed and summed with equal weight.

    Returns:
        Tuple ``(loss, main_loss, blob_loss)``.

    Example inputs should look like the following (DiceLoss is not imported in
    this file; presumably monai.losses.DiceLoss):
    blob_loss_dict = {
        "main_weight": 1,
        "blob_weight": 0,
    }
    criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }
    blob_criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }
    """
    main_weight = blob_loss_dict["main_weight"]
    blob_weight = blob_loss_dict["blob_weight"]

    # With no positive weight, both terms are needed for the equal-weight
    # fallback. Previously neither was computed, so the fallback raised NameError.
    equal_weighting = not (main_weight > 0 or blob_weight > 0)

    main_loss = 0
    blob_loss = 0
    # main loss
    if main_weight > 0 or equal_weighting:
        main_loss = compute_compound_loss(criterion_dict=criterion_dict, raw_network_outputs=raw_network_outputs,
                                          label=binary_label, blob_loss_mode=False)
    # blob loss
    if blob_weight > 0 or equal_weighting:
        blob_loss = compute_compound_loss(criterion_dict=blob_criterion_dict, raw_network_outputs=raw_network_outputs,
                                          label=multi_label, blob_loss_mode=True)

    # final loss
    if main_weight > 0 and blob_weight > 0:
        loss = main_loss * main_weight + blob_loss * blob_weight
    elif main_weight > 0:
        loss = main_loss
    elif blob_weight > 0:
        loss = blob_loss
    else:
        vprint("defaulting to equal weighted blob loss")
        loss = main_loss + blob_loss
    vprint("blob loss:", blob_loss)
    vprint("main loss:", main_loss)
    vprint("effective loss:", loss)
    return loss, main_loss, blob_loss
def one_hot(net_out, target):
    """Return ``target`` one-hot encoded along dim 1 to match ``net_out``'s shape.

    Args:
        net_out: reference tensor of shape (b, c, ...). Only its shape and
            device are used.
        target: label map of shape (b, ...) or (b, 1, ...), or a tensor that
            already matches ``net_out``'s shape.

    Returns:
        If ``target`` (after adding a channel dim when ranks differ) already has
        ``net_out``'s shape, ``target`` itself (possibly as a view). Otherwise a
        new tensor of the default float dtype, shaped like ``net_out`` and on
        ``net_out.device``. It holds 1 in the channel named by each label and 0
        elsewhere.
    """
    if len(net_out.shape) != len(target.shape):
        # Add a channel dimension: (b, ...) -> (b, 1, ...).
        target = target.view(target.shape[0], 1, *target.shape[1:])
    if all(i == j for i, j in zip(net_out.shape, target.shape)):
        return target
    # Allocate directly on net_out's device so scatter_ works for CUDA (and any
    # other device) inputs; .cuda() is not in-place and was never assigned.
    encoded = torch.zeros(net_out.shape, device=net_out.device)
    encoded.scatter_(1, target.long(), 1)
    return encoded
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """Return the soft sums ``(tp, tn, fp, fn)`` between predictions and ground truth.

    Adapted from:
    https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/dice_loss.py

    Args:
        net_output: predictions of shape (b, c, x, y(, z)), presumably
            probabilities in [0, 1] (tn/fn use ``1 - net_output``).
        gt: label map of shape (b, 1, x, y(, z)) or (b, x, y(, z)), or a one-hot
            encoding of shape (b, c, x, y(, z)).
        axes: unused. Accepted for signature compatibility; every element is
            summed.
        mask: optional tensor of shape (b, 1, x, y(, z)), 1 for valid pixels and
            0 for invalid pixels. Applied to tp, tn, fp and fn alike.
        square: if True, tp, tn, fp and fn are squared before summation.

    Returns:
        Tuple of 0-dim tensors ``(tp, tn, fp, fn)``.
    """
    shp_x = net_output.shape
    shp_y = gt.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
        if all(i == j for i, j in zip(shp_x, gt.shape)):
            # If this is the case then gt is probably already a one-hot encoding.
            y_onehot = gt
        else:
            # Allocate on net_output's device (works for CPU, CUDA and others).
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt.long(), 1)

    tp = net_output * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # mask[:, :1] broadcasts the single mask channel over all c channels.
        # tn is masked too (previously it was left unmasked).
        channel_mask = mask[:, :1]
        tp, tn, fp, fn = (t * channel_mask for t in (tp, tn, fp, fn))
    if square:
        tp, tn, fp, fn = (t ** 2 for t in (tp, tn, fp, fn))

    return tp.sum(), tn.sum(), fp.sum(), fn.sum()
def SoftDiceLoss(x, y, loss_mask=None, smooth=1e-7):
    """Return the soft Dice loss ``1 - dice`` between predictions and ground truth.

    Args:
        x: predictions, in the form accepted by get_tp_fp_fn (b, c, ...).
        y: label map or one-hot ground truth, in the form accepted by get_tp_fp_fn.
        loss_mask: optional (b, 1, ...) validity mask (1 = valid pixel).
        smooth: additive smoothing term in numerator and denominator. It
            avoids division by zero when both prediction and target are empty.

    Returns:
        ``1 - (2*tp + smooth) / (2*tp + fp + fn + smooth)``, averaged with
        ``.mean()``.
    """
    # Pass the mask by keyword. Positionally it landed in the ``axes`` slot and
    # was silently ignored.
    tp, _, fp, fn = get_tp_fp_fn(x, y, mask=loss_mask)
    dc = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    return 1 - dc.mean()
if __name__ == '__main__':
    # Toy demo: a batch of one 5x5 image with 4 output channels.
    channels = [
        [[0.8, 0.8, 0.8, 0.1, 0.1],
         [0.8, 0.8, 0.8, 0.1, 0.1],
         [0.8, 0.8, 0.8, 0.1, 0.1],
         [0.1, 0.1, 0.1, 0.2, 0.2],
         [0.1, 0.1, 0.1, 0.2, 0.2]],
        [[0.05, 0.05, 0.05, 0.8, 0.8],
         [0.05, 0.05, 0.05, 0.8, 0.8],
         [0.05, 0.05, 0.05, 0.8, 0.8],
         [0.04, 0.04, 0.04, 0.05, 0.05],
         [0.04, 0.04, 0.04, 0.05, 0.05]],
        [[0.04, 0.04, 0.04, 0.03, 0.03],
         [0.04, 0.04, 0.04, 0.03, 0.03],
         [0.04, 0.04, 0.04, 0.03, 0.03],
         [0.7, 0.7, 0.7, 0.15, 0.15],
         [0.7, 0.7, 0.7, 0.15, 0.15]],
        [[0.09, 0.09, 0.09, 0.07, 0.07],
         [0.09, 0.09, 0.09, 0.07, 0.07],
         [0.09, 0.09, 0.09, 0.07, 0.07],
         [0.16, 0.16, 0.16, 0.6, 0.6],
         [0.16, 0.16, 0.16, 0.6, 0.6]],
    ]
    net_out = torch.tensor([channels])

    # Integer label map: values 1, 2 and 3 mark three components; 0 is background.
    instance_map = [
        [1, 1, 1, 2, 2],
        [1, 1, 1, 2, 2],
        [1, 1, 1, 2, 2],
        [3, 3, 3, 0, 0],
        [3, 3, 3, 0, 0],
    ]
    target = torch.tensor([instance_map])

    # Foreground / background mask derived from the same layout.
    foreground_map = [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
    ]
    binary_label = torch.tensor([foreground_map])

    # NOTE(review): one_hot turns the instance ids into 0/1 channels, so the
    # blob loss sees a single blob (value 1) per element rather than one blob
    # per connected component.
    multi_label = one_hot(net_out, target)

    blob_loss_dict = {
        "main_weight": 1,
        "blob_weight": 1,
    }

    def _make_criteria():
        # Fresh loss objects on every call, so the two dicts stay independent.
        return {
            "bce": {
                "name": "bce",
                "loss": BCEWithLogitsLoss(reduction="mean"),
                "weight": 1.0,
                "sigmoid": False,
            },
            "dice": {
                "name": "dice",
                "loss": SoftDiceLoss,
                "weight": 1.0,
                "sigmoid": False,
            },
        }

    criterion_dict = _make_criteria()
    blob_criterion_dict = _make_criteria()

    result = compute_loss(
        blob_loss_dict=blob_loss_dict,
        criterion_dict=criterion_dict,
        blob_criterion_dict=blob_criterion_dict,
        raw_network_outputs=net_out,
        binary_label=binary_label,
        multi_label=multi_label,
    )
    print(result)
# blob_loss
# First published 2022-06-06 21:11:52 (于 2022-06-06 21:11:52 首次发布)