import torch
import torch.nn.functional as F
# Weighted KL-divergence distillation loss (klloss_v2)
def klloss_v2(logits_t, input, target, label, beta, offset=8.0):
    """Compute a class-weighted KL distillation loss between student and teacher.

    Each class's KL contribution ``p_t * (log p_t - log p_s)`` is reweighted by
    the teacher-logit gap between the true class and that class,
    ``(logits_t[i, label[i]] - logits_t[i, j]) / beta + offset``.

    Usage elsewhere (from the original comment):
    ``nckd_loss = klloss_v2(origin_logits_t, logits_student, logits_teacher,
    target, alpha) * temperature ** 2`` — note the caller scales by T^2.

    Args:
        logits_t: Teacher logits used to build the weight matrix,
            shape (batch, num_classes).
        input: Student logits, shape (batch, num_classes).
        target: Teacher logits treated as soft targets, same shape as ``input``.
        label: Ground-truth class indices, shape (batch,).
        beta: Scale applied to the teacher-logit differences.
        offset: Constant added after scaling; defaults to 8.0, which was
            previously hard-coded.

    Returns:
        Scalar tensor: the weighted KL loss averaged over the batch.
    """
    print("输入参数:")
    print("logits_t (教师模型的logits):\n", logits_t)
    print("input (学生模型的logits):\n", input)
    print("target (教师模型的target):\n", target)
    print("label (真实标签):\n", label)
    # Per-class KL contribution of the teacher distribution vs. the student.
    log_input = F.log_softmax(input, dim=1)
    log_target = F.log_softmax(target, dim=1)
    target_prob = F.softmax(target, dim=1)
    print("\nlog_input (学生模型的 log_softmax):\n", log_input)
    print("log_target (教师模型的 log_softmax):\n", log_target)
    print("target (教师模型的 softmax):\n", target_prob)
    output = target_prob * (log_target - log_input)
    print("\noutput (softmax 后的概率差异):\n", output)
    # Vectorized replacement for the original per-sample Python loop:
    # matrix[i, j] = logits_t[i, label[i]] - logits_t[i, j].
    # gather() needs the index on the same device as the teacher logits,
    # so we move `label` there instead of pulling it to the CPU.
    teacher = logits_t.detach()
    idx = label.long().view(-1, 1).to(teacher.device)
    matrix = teacher.gather(1, idx) - teacher
    print("\n组合后的差异矩阵 (matrix):\n", matrix)
    # Scale by beta and shift by `offset` (previously the magic constant 8.0).
    matrix = matrix / beta + offset
    print("\n缩放和偏移后的矩阵 (matrix):\n", matrix)
    # Weighted KL summed over classes, averaged over the batch.
    loss = (matrix * output).sum() / input.shape[0]
    print("\n最终计算的损失 (loss):\n", loss)
    return loss
# Demo: run klloss_v2 on a small hand-made batch.
if __name__ == "__main__":
    # Teacher logits, shape (batch=2, num_classes=5).
    logits_t = torch.tensor([[2.5, 1.2, 0.8, 3.0, 2.0],
                             [1.0, 2.0, 3.5, 0.5, 1.8]], dtype=torch.float32)
    # Student logits (renamed from `input` so the builtin is not shadowed).
    logits_s = torch.tensor([[2.0, 1.0, 0.5, 2.5, 1.5],
                             [0.5, 2.0, 3.0, 0.2, 1.5]], dtype=torch.float32)
    # Teacher soft targets (here identical to the teacher logits above).
    target = torch.tensor([[2.5, 1.2, 0.8, 3.0, 2.0],
                           [1.0, 2.0, 3.5, 0.5, 1.8]], dtype=torch.float32)
    label = torch.tensor([3, 2])  # ground-truth class indices
    beta = 1.5  # scaling hyper-parameter for the logit-difference weights
    # Compute and print the loss (klloss_v2 prints its intermediates).
    loss = klloss_v2(logits_t, logits_s, target, label, beta)