假设某分类网络有10个类,输出的得分为 logits,我们仅考虑1batch_size的情况,即 logits 可以理解为 float数组 logits[10].
logits经过softmax对应得到的标签值 preds 的公式为:
如果加上温度系数,则公式写为:
是不是看起来不直观,我们来画个图看下加温度系数temp前后preds值得变化。
注:为了看起来更直观 logits[10]我取了标准正态分布,设temp=0.01。
图左为没有温度系数temp(也可以理解为temp=1)的preds数组柱状图,图右为加了温度系数temp=0.01的preds柱状图。
取更小的温度系数temp,标签值preds会变得更加尖锐化。使问题更加趋向于非黑即白的情况。
以下是相应的python代码:
import random
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
mixed_target = torch.ones((10))
mixed_target = mixed_target/10.0
mixed_temp = torch.ones((10))
mixed_temp = mixed_temp/100.0
logits = torch.randn((10))
logits = logits/100.0
preds = F.softmax(logits, dim=0)
preds_n = preds.numpy()
preds = torch.clamp(preds, min=1e-8)
preds = torch.log(preds)
loss = -torch.mean(torch.sum(mixed_target * preds, dim=0))
preds_t = F.softmax(logits / mixed_temp, dim=0)
preds_nt = preds_t.numpy()
preds_t = torch.clamp(preds_t, min=1e-8)
preds_t = torch.log(preds_t)
loss_t = -torch.mean(torch.sum(mixed_target * preds_t, dim=0))
X_set = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
p1 = plt.bar(X_set, np.around(preds_n, 3), width=0.35, label='value')
plt.bar_label(p1, label_type='edge')
plt.title('The distribution of softmax')
plt.show()
p2 = plt.bar(X_set, np.around(preds_nt, 3), width=0.35, label='value')
plt.bar_label(p2, label_type='edge')
plt.title('The distribution of softmax with tempreture')
plt.show()
p3 = plt.bar(['temp', 'no_temp'], [loss_t, loss], width=0.35, label='value')
plt.bar_label(p3, label_type='edge')
plt.title('The distribution of softmax with tempreture')
plt.show()