在loss.py文件里面新增了一个损失函数VFloss。该函数主要是用于替换Focal loss(两个损失函数并不能说谁好谁坏,各有自己的看法吧)
代码如下:
class VFloss(nn.Module):
def __init__(self,loss_fcn,gamma=1.5,alpha=0.25):
super(VFloss, self).__init__()
#传递nn.BCEWithLogitsLoss()损失函数
self.loss_fcn = loss_fcn
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' #这里的‘none’原本是‘mean’,但是上文的loss函数这里都用了‘none’,且用‘mean’会报错
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
pred_prob = torch.sigmoid(pred)
focal_weight = true * (true > 0.0).float() + self.alpha * (pred_prob - true).abs().pow(self.gamma) * (true<=0.0).float()
loss *=focal_weight
if self.reduction == 'mean':
return loss.mean()
elif self.re