小问题不要慌!!!!
运行代码:
import sys
sys.path.append('..')
import torch
def simple_batch_norm_1d(x, gamma, beta):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # dim=0在每一列上求取均值 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
# 5行3列表示三个特征,每个特征上有五个数据点
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
y = y.float()
print('after bn: ')
print(y)
该代码是学习pytorch数据标准化的代码,对一个tensor求一个均值和方差。
报错如下:
该错误提示也很明显,在求均值的时候数据类型不对,计算得到的是个long型,对其数据类型做个转换即可。
修改如下:
x_mean = torch.mean(x.float(), dim=0, keepdim=True)
这是运行就没错误啦!!!!