'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数

在编写SVM中的Hinge loss函数的时候报错“'int' object has no attribute 'backward'”

for epoch in range(50):
    for batch in dataloader:
        opt.zero_grad()
        output=hinge_loss(svm(batch[0],w,b),batch[1]) 
        output.backward()
        opt.step()
    draw_margin(w, b, camera)

报错的原因是output,也就是损失函数这里输出了int值。但是在实验过程中,梯度确实是下下降了。只是总是在下降过程中出现了这种报错。

 

经过排错,发现了hinge_loss函数中出现了问题

def hinge_loss(y_pred,y_true):
    return (0 if 0>(1-y_pred*y_true).mean() else (1-y_pred*y_true).mean()) 

 

注意,此时如果 0>(1-y_pred*y_true).mean() 这个条件成立,函数会返回0。0本身是个int类型的数。这就是为什么有时候梯度可以下降。有时候不可以。

 

将原函数修改为:

def hinge_loss(y_pred,y_true):
    #这里注意,返回0的时候,0的类型需要是tensor类型,并且需要梯度。
    return (torch.tensor(0.0,requires_grad=True) if 0>(1-y_pred*y_true).mean() else (1-y_pred*y_true).mean())  

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值