关于Pytorch中的MSELoss的坑

在pytorch中,经常使用nn.MSELoss作为损失函数,例如

loss=nn.MSELoss()
input=torch.randn(3,5,requires_grad=True)
target=torch.randn(3,5)
error=loss(input,target)
error.backward()

这个地方有一个巨坑,就是一定要小心input和target的位置,说的更具体一些,target一定需要是一个不能被训练更新的、requires_grad=False的值,否则会报错!!!

 

另外,关于MSELoss的设定

若设定loss=torch.nn.MSELoss(reduction='mean'),最终输出值是(target-input)每个元素数字平方和除以width x height,也就是在batch和特征维度上都做了平均。如果只想在batch上做平均,则可以写成这个样子:

#需要注意的是,这里的input和target是mini-batch的形式
loss=torch.nn.MSELoss(reduction='sum')
loss=loss(input,target)/target.size(0)

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值