pytorch报错: Can only calculate the mean of floating types. Got Long instead

小问题不要慌!!!!
运行代码:

import sys
sys.path.append('..')
import torch

def simple_batch_norm_1d(x, gamma, beta):
    eps = 1e-5
    x_mean = torch.mean(x, dim=0, keepdim=True)  # dim=0在每一列上求取均值  保留维度进行 broadcast
    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

#  5行3列表示三个特征,每个特征上有五个数据点
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
y = y.float()
print('after bn: ')
print(y)

该代码是学习pytorch数据标准化的代码,对一个tensor求一个均值和方差。
报错如下:
在这里插入图片描述
该错误提示也很明显,在求均值的时候数据类型不对,计算得到的是个long型,对其数据类型做个转换即可。
修改如下:

x_mean = torch.mean(x.float(), dim=0, keepdim=True) 

这是运行就没错误啦!!!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

臭皮匠-hfW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值