_get_gt_mask、cat_mask、_get_other_mask

import torch

# 定义获取标签掩码的函数
def _get_gt_mask(logits, target):
    print("原始 logits:\n", logits)
    print("目标 target:\n", target)
    
    # 将 target 拉平为一维张量
    target = target.reshape(-1)
    print("拉平后的 target:\n", target)
    
    # 创建一个和 logits 大小相同的全零张量,然后根据 target 将对应的类别位置设置为1
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    print("生成的标签掩码 mask:\n", mask)
    
    # 返回根据 target 设置的标签掩码
    return mask

# 定义组合掩码的函数
def cat_mask(t, mask1, mask2):
    print("输入张量 t:\n", t)
    print("标签掩码 mask1:\n", mask1)
    print("非标签掩码 mask2:\n", mask2)
    
    # 计算 mask1 对应的 t 值,sum(dim=1) 表示在类别维度上进行求和
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    print("标签类别的加权和 t1:\n", t1)
    
    # 计算 mask2 对应的 t 值
    t2 = (t * mask2).sum(1, keepdims=True)
    print("非标签类别的加权和 t2:\n", t2)
    
    # 将两个值拼接成新的张量
    rt = torch.cat([t1, t2], dim=1)
    print("拼接后的结果 rt:\n", rt)
    return rt

# 示例:假设有3个样本和5个类别的logits
logits = torch.tensor([[2.0, 1.0, 0.1, 3.0, 0.5],
                       [1.0, 3.0, 2.5, 0.5, 0.3],
                       [0.5, 2.2, 1.1, 4.0, 1.5]])

# 对应的标签 target
target = torch.tensor([3, 1, 4])  # 每个样本的正确类别是3, 1, 4

# 获取标签掩码
gt_mask = _get_gt_mask(logits, target)

# 获取非标签掩码
def _get_other_mask(logits, target):
    print("原始 logits:\n", logits)
    print("目标 target:\n", target)
    
    # 将 target 拉平为一维张量
    target = target.reshape(-1)
    print("拉平后的 target:\n", target)
    
    # 创建一个和 logits 大小相同的全1张量,然后根据 target 将对应的类别位置设置为0
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    print("生成的非标签掩码 mask:\n", mask)
    
    return mask

other_mask = _get_other_mask(logits, target)

# 假设有某些 softmax 结果
t = torch.softmax(logits, dim=1)
print("Softmax 后的 logits (概率值):\n", t)

# 使用标签掩码和非标签掩码进行组合
combined = cat_mask(t, gt_mask, other_mask)
print("最终组合后的结果:\n", combined)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值