import torch
import torch.nn as nn
def ss_loss(net_out, target, smooth=1e-6):
# 计算标签真值之内的损失
molecule_1 = ((target - net_out) ** 2) * target
denominator_1 = target
loss_1 = (molecule_1.sum() + smooth) / (denominator_1.sum() + smooth)
# 计算真值之外的损失
molecule_2 = ((target - net_out) ** 2) * (1 - target)
denominator_2 = (1 - target)
loss_2 = (molecule_2.sum() + smooth) / (denominator_2.sum() + smooth)
return 1 * loss_1 + 1 * loss_2
class Mismatch_loss(nn.Module):
def __init__(self):
super(Mismatch_loss, self).__init__()
def forward(self, net_out, target, max_positiones):
losses = []
for j in range(net_out.shape[0]):
for i in range(0, net_out.shape[1]):
max_target = torch.max(target[j, i, ...])
max_position = torch.max(max_positiones[j, i, ...])
if max_target == 0 and max_position == 0:
continue
else:
loss = ss_loss(net_out[j, i, ...], target[j, i, ...])
losses.append(loss)
losses = torch.stack(losses)
return losses.mean()
if __name__ == '__main__':
net_out = torch.tensor(
[[[[0.3, 0.3, 0.3, 0.9, 0.3, 0.3, 0.3, 0.3],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.3, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
[0.3, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]]]]
)
# max_netout是根据net_out进行softmax得来的
max_netout = torch.tensor(
[[[[0, 0, 0, 1, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 0, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1]]]]
)
target = torch.tensor(
[[[[0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1]]]]
)
classes = Mismatch_loss()
loss = classes(net_out, target, max_netout)
print(loss)
ss_loss
于 2022-06-19 12:31:05 首次发布
该博客探讨了一种名为`Mismatch_loss`的自定义损失函数,它结合了`ss_loss`来计算真值内外的损失。`ss_loss`通过计算目标值与网络输出之间的平方差并考虑平滑项来确定损失。`Mismatch_loss`在每个位置上对网络输出和目标值进行比较,并只计算非零最大目标和位置的损失。在提供的示例中,该损失函数被应用于具有多个通道和位置的张量,并在主函数中用具体的数据进行了演示。
摘要由CSDN通过智能技术生成