今天在做一个简单的线性回归问题:
import torch
import torch.nn as nn
learning_rate = 0.01
# 1.准备数据
x = torch.rand([500, 1])
y_true = x * 3 + 0.8
# 2.通过模型计算y_predict
w = torch.rand([1, 1], requires_grad=True)
b = torch.tensor(0, requires_grad=True, dtype=torch.float32)
for i in range(500):
y_predict = torch.matmul(x, w) + b
loss = (y_predict - y_true).pow(2).mean()
if w.grad is not None:
w.grad.data.zero_()
if b.grad is not None:
b.grad.data.zero_()
output.backward() # 反向传播
w.data = w.data - learning_rate * w.grad
b.data = b.data - learning_rate * b.grad
print("w, b, loss", w.item(), b.item(), loss.item())
一切都好好的,但是我想自己尝试着改一下代码,把这个求损失函数的方法改一下,于是,就改成了
loss = nn.BCELoss()
output = loss(y_predict, y_true)
可是改完损失就变成了负数,后来在同学的帮助下,我才知道了
(
y
_
p
r
e
d
i
c
t
−
y
_
t
r
u
e
)
2
(y\_predict - y\_true)^{2}
(y_predict−y_true)2是均方误差。
我使用的BCELoss属于交叉熵损失,一般用来求二分类问题的。这两种损失函数有天壤之别,我求线性回归需要判断y预测值和y真实值的差别是多少,而且不一定只有两个参数,还和其他参数有关。
常规分类网络最后的softmax层如下图所示,一共有K类,令网络的输出为[y^1,…, y^K],对应每个类别的概率,令label为 [y1,…,yK]。对某个属于p类的样本,其label中yp=1,y1,…,yp−1,yp+1,…,yK均为0。
交叉熵(Cross entropy)损失为:
L
=
−
(
y
1
log
y
^
1
+
⋯
+
y
K
log
y
^
K
)
=
−
y
p
log
y
^
p
=
−
log
y
^
p
\begin{aligned}L &= - (y_1 \log \hat{y}_1 + \dots + y_K \log \hat{y}_K) \\&= -y_p \log \hat{y}_p \\ &= - \log \hat{y}_p\end{aligned}
L=−(y1logy^1+⋯+yKlogy^K)=−yplogy^p=−logy^p
均方误差损失(mean squared error,MSE)为:
L
=
(
y
1
−
y
^
1
)
2
+
⋯
+
(
y
K
−
y
^
K
)
2
=
(
1
−
y
^
p
)
2
+
(
y
^
1
2
+
⋯
+
y
^
p
−
1
2
+
y
^
p
+
1
2
+
⋯
+
y
^
K
2
)
\begin{aligned}L &= (y_1 - \hat{y}_1)^2 + \dots + (y_K - \hat{y}_K)^2 \\&= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \dots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \dots + \hat{y}_K^2)\end{aligned}
L=(y1−y^1)2+⋯+(yK−y^K)2=(1−y^p)2+(y^12+⋯+y^p−12+y^p+12+⋯+y^K2)
则m个样本的损失为:
ℓ
=
1
m
∑
i
=
1
m
L
i
\ell = \frac{1}{m} \sum_{i=1}^m L_i
ℓ=m1i=1∑mLi
可以看出交叉熵损失只和标签有关,yhat越接近1越好,或者说交叉熵是为了解决二分类能够结果趋近正确答案的问题。但是均方误差会受到很多因素的影响。对于我写的这个简单线性回归问题,关心的是y_predict和y_true的接近程度,而不是接近1的程度,所以应当使用均方误差来求误差。交叉熵的损失函数只和分类正确的预测结果有关系,而MSE的损失函数还和错误的分类有关系,该分类函数除了让正确的分类尽量变大,还会让错误的分类变得平均,但实际在分类问题中这个调整是没有必要的。但是对于回归问题来说,这样的考虑就显得很重要了。所以,回归问题熵使用交叉上并不合适。另外,平方损失函数表示数据服从正态分布,但是分类问题不服从正态分布,比如二分类服从伯努利分布。
参考资料:https://www.cnblogs.com/shine-lee/p/12032066.html
https://zhuanlan.zhihu.com/p/159063829