如何构建置信度掩码(代码)

import torch
import torch.nn.functional as F

# 假设 preds 是模型输出的预测结果,尺寸为 [2, 256, 256]
# 假设 labels 是真实标签图像,尺寸为 [256, 256],值为 0 或 1

# 假设的模型输出和标签(随机生成,实际使用时替换为真实数据)
preds = torch.rand(2, 256, 256)
labels = torch.randint(0, 2, (256, 256))

# 应用 softmax 来获取概率分布
probs = F.softmax(preds, dim=0)

# 选择每个像素点概率最高的类别
confidences, predictions = torch.max(probs, dim=0)

# 置信度阈值
tau = 0.6

# 生成置信度掩码
confidence_mask = confidences >= tau
filtered_labels = labels[confidence_mask]

confidence_mask = confidence_mask.unsqueeze(0).repeat(2, 1, 1)# 使用置信度掩码过滤掉不满足置信度要求的预测
filtered_probs = probs[confidence_mask]

data_part1 = filtered_probs[:int(filtered_probs.size(0)/2)]
data_part2 = filtered_probs[int(filtered_probs.size(0)/2):]

# 然后将这两部分堆叠起来形成所需的形状
reshaped_data = torch.stack((data_part1, data_part2), dim=1)
# 计算 CE 损失,只考虑置信度高的像素点

print(reshaped_data)
print(filtered_labels)

print(reshaped_data.size())
print(filtered_labels.size())
loss = F.cross_entropy(reshaped_data, filtered_labels)

print(loss)

对于CE损失函数,要求的尺寸预测图像要比GT多一维,所以在代码中特别地构建了reshape_data ,这样就可以计算损失了。

代码写得不是很规范,欢迎指正

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值