My own test works:
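import torch
x = torch.arange(256).view(4, 4, 4, 4).float()
x_mean = torch.mean(x, axis=[2, 3], keepdim=True)  # axis= is accepted here as an alias for dim=
nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)  # use the tensor x; note keepdim, not keepdims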
This line raises an error in older PyTorch versions:
nu2 = torch.mean(input.pow(2), axis=[2, 3], keepdims=True)
error:
{TypeError}mean() received an invalid combination of arguments - got (Tensor, dim=list, keepdims=bool), but expected one of:
* (Tensor input)
* (Tensor input, torch.dtype dtype)
* (Tensor input, tuple of ints dim, torch.dtype dtype, Tensor out)
* (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
* (Tensor input, tuple of ints dim, bool keepdim, Tensor out)
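The cause is the keepdims keyword: torch.mean's parameter is named keepdim (no trailing s). keepdims is the NumPy spelling, which older versions do not accept. The message shows dim=list, so axis itself was accepted as an alias for dim in that version; only keepdims was rejected. Writing dim= and keepdim=True, as in the working line above, avoids the problem.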
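A minimal sketch of the portable spelling, reusing the same (4, 4, 4, 4) tensor as the test above. It only uses dim= and keepdim=, the names in PyTorch's own signature, so it does not depend on whether a given version maps the NumPy-style aliases. The shape checks show what keepdim=True actually does, and the manual sum/(H*W) comparison is just an illustration that the result is the per-channel mean of squares.

```python
import torch

# Same 4-D tensor as in the test: shape (N, C, H, W) = (4, 4, 4, 4)
x = torch.arange(256, dtype=torch.float32).view(4, 4, 4, 4)

# Portable call: dim= and keepdim= are PyTorch's native argument names,
# so no NumPy-style aliases (axis=, keepdims=) are involved.
nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

# keepdim=True keeps the reduced axes as size-1 dimensions,
# so nu2 still broadcasts against x.
print(nu2.shape)  # torch.Size([4, 4, 1, 1])

# Without keepdim, the reduced axes are dropped.
print(x.pow(2).mean(dim=[2, 3]).shape)  # torch.Size([4, 4])

# Same values by hand: sum over H and W, then divide by H * W.
manual = x.pow(2).sum(dim=[2, 3], keepdim=True) / (x.shape[2] * x.shape[3])
print(torch.allclose(nu2, manual))  # True
```

The size-1 dimensions kept by keepdim=True are what make expressions like x / (nu2 + eps).sqrt() broadcast correctly per sample and channel.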