给定损失函数的输入y,pred,shape均为bxc。
若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。
如果只想在batch上做平均,可以这样写:
loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)
给定损失函数的输入y,pred,shape均为bxc。
若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。
如果只想在batch上做平均,可以这样写:
loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)