import torch
# 定义获取标签掩码的函数
def _get_gt_mask(logits, target):
print("原始 logits:\n", logits)
print("目标 target:\n", target)
# 将 target 拉平为一维张量
target = target.reshape(-1)
print("拉平后的 target:\n", target)
# 创建一个和 logits 大小相同的全零张量,然后根据 target 将对应的类别位置设置为1
mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
print("生成的标签掩码 mask:\n", mask)
# 返回根据 target 设置的标签掩码
return mask
# 定义组合掩码的函数
def cat_mask(t, mask1, mask2):
print("输入张量 t:\n", t)
print("标签掩码 mask1:\n", mask1)
print("非标签掩码 mask2:\n", mask2)
# 计算 mask1 对应的 t 值,sum(dim=1) 表示在类别维度上进行求和
t1 = (t * mask1).sum(dim=1, keepdims=True)
print("标签类别的加权和 t1:\n", t1)
# 计算 mask2 对应的 t 值
t2 = (t * mask2).sum(1, keepdims=True)
print("非标签类别的加权和 t2:\n", t2)
# 将两个值拼接成新的张量
rt = torch.cat([t1, t2], dim=1)
print("拼接后的结果 rt:\n", rt)
return rt
# 示例:假设有3个样本和5个类别的logits
logits = torch.tensor([[2.0, 1.0, 0.1, 3.0, 0.5],
[1.0, 3.0, 2.5, 0.5, 0.3],
[0.5, 2.2, 1.1, 4.0, 1.5]])
# 对应的标签 target
target = torch.tensor([3, 1, 4]) # 每个样本的正确类别是3, 1, 4
# 获取标签掩码
gt_mask = _get_gt_mask(logits, target)
# 获取非标签掩码
def _get_other_mask(logits, target):
print("原始 logits:\n", logits)
print("目标 target:\n", target)
# 将 target 拉平为一维张量
target = target.reshape(-1)
print("拉平后的 target:\n", target)
# 创建一个和 logits 大小相同的全1张量,然后根据 target 将对应的类别位置设置为0
mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
print("生成的非标签掩码 mask:\n", mask)
return mask
other_mask = _get_other_mask(logits, target)
# 假设有某些 softmax 结果
t = torch.softmax(logits, dim=1)
print("Softmax 后的 logits (概率值):\n", t)
# 使用标签掩码和非标签掩码进行组合
combined = cat_mask(t, gt_mask, other_mask)
print("最终组合后的结果:\n", combined)