import torch
# 自定义数据
x = torch.randn((2,3))
y = torch.randn((2,3))
# 自定义MSE
def MSELoss(pred,target):
return torch.pow(pred-target,2)
# 实例化类
MSE = torch.nn.MSELoss(reduction="mean")
loss0 = MSELoss(x,y).mean()
loss1 = MSE(x,y)
print(loss0, loss1)
注意,loss0后面有 .mean(),与调用类的reduction参数一致