- 在【Pytorch实战(三)】线性回归中介绍了线性回归。线性判决与线性回归是类似的,区别在于线性回归的“标签”通常是较为连续的取值,而线性判决的“标签”往往只有几个可能的取值,例如【0,1】。
- 在线性回归中,经常使用MSE损失
torch.nn.MSELoss()
、L1损失torch.nn.L1Loss
和SmoothL1损失torch.nn.SmoothL1Loss
等。相应地,在线性判决中也需使用合适的损失。对于二元判决常用BCE(Binary Cross Entropy)损失;对于多元则常用Cross Entropy损失。
- 以二元判决为例,计算BCE损失需要获得一个预测概率,这需要使用逻辑回归算法(对应于
torch.sigmoid()
函数);若拓展至多元则对应于softmax函数。具体公式不再赘述,可参照结合实例理解pytorch中交叉熵损失函数进行理解。
- 实际实现中,二元的
torch.nn.BCELoss()
和torch.nn.BCEWithLogitsLoss()
又有所区别,前者需传入预测概率和真实标签,而后者则需传入W与X线性运算的结果Z和真实标签。多元的则使用torch.nn.CrossEntropyLoss
即可。
- 以下为例程:
import torch
import torch.nn
import torch.optim
x = torch.tensor([[1., 2<