adkd

import torch
import torch.nn.functional as F

# 定义 klloss_v2 函数
def klloss_v2(logits_t, input, target, label, beta):
    # nckd_loss=klloss_v2(origin_logits_t,logits_student,logits_teacher,target,alaph)*(temperature**2)
    print("输入参数:")
    print("logits_t (教师模型的logits):\n", logits_t)
    print("input (学生模型的logits):\n", input)
    print("target (教师模型的target):\n", target)
    print("label (真实标签):\n", label)
    
    # 计算 log_softmax 和 softmax
    log_input = F.log_softmax(input, dim=1)
    log_target = F.log_softmax(target, dim=1)
    target = F.softmax(target, dim=1)

    print("\nlog_input (学生模型的 log_softmax):\n", log_input)
    print("log_target (教师模型的 log_softmax):\n", log_target)
    print("target (教师模型的 softmax):\n", target)

    # 计算 output = target * (log_target - log_input)
    output = target * (log_target - log_input)
    print("\noutput (softmax 后的概率差异):\n", output)

    # 计算每个样本真实标签类别与其他类别之间的距离矩阵
    matrix = []
    for (x, y) in zip(label.cpu(), logits_t.detach()):
        print(f"\n处理样本 {x}:")
        print("教师模型的 logits (y):\n", y)
        diff = y[x] - y
        print(f"与真实类别 {x} 的差异 (y[x] - y):\n", diff)
        matrix.append(diff)

    # 合并 matrix,并 reshape 为 [batch_size, num_classes] 的形状
    matrix = torch.cat(matrix)
    matrix = matrix.reshape(-1, input.size(1))
    print("\n组合后的差异矩阵 (matrix):\n", matrix)

    # 对矩阵进行缩放和偏移处理
    matrix = matrix / beta
    matrix = matrix + 8.0
    print("\n缩放和偏移后的矩阵 (matrix):\n", matrix)

    # 计算损失
    loss = (matrix * output).sum()
    loss = loss / input.shape[0]
    print("\n最终计算的损失 (loss):\n", loss)
    
    return loss

# 示例数据
logits_t = torch.tensor([[2.5, 1.2, 0.8, 3.0, 2.0], 
                         [1.0, 2.0, 3.5, 0.5, 1.8]], dtype=torch.float32)  # 教师模型 logits
input = torch.tensor([[2.0, 1.0, 0.5, 2.5, 1.5], 
                      [0.5, 2.0, 3.0, 0.2, 1.5]], dtype=torch.float32)  # 学生模型 logits
target = torch.tensor([[2.5, 1.2, 0.8, 3.0, 2.0], 
                       [1.0, 2.0, 3.5, 0.5, 1.8]], dtype=torch.float32)  # 教师模型的 target (可理解为软标签)
label = torch.tensor([3, 2])  # 真实类别
beta = 1.5  # 调节参数

# 调用 klloss_v2 函数
loss = klloss_v2(logits_t, input, target, label, beta)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值