-
自定义HingeLoss
class MyHingeLoss(torch.nn.Module):
# 不要忘记继承Module
def __init__(self):
super(MyHingeLoss, self).__init__()
def forward(self, output, target):
"""output和target都是1-D张量,换句话说,每个样例的返回是一个标量.
"""
hinge_loss = 1 - torch.mul(output, target)
hinge_loss[hinge_loss < 0] = 0
# 不要忘记返回scalar
return torch.mean(hinge_loss)
参考:https://zhuanlan.zhihu.com/p/80827719
-
自定义MSE损失
class My_loss(torch.nn.Module):
def __init__(self):
super(My_loss, self).__init__()
def forward(self, x, y):
return torch.mean(torch.pow((x-y), 2))