pytorch学习 | 如何统计数据集的均值和标准差?

我们在使用模型训练之前一般要对数据进行归一化(Normalize),归一化之前需要得到数据集整体的方差和均值,这里提供了一个简单计算数据标准差和均值的接口,方便大家使用。

def get_mean_std(dataset, ratio=0.01):
    """Get mean and std by sample ratio
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=int(len(dataset)*ratio), 
                                             shuffle=True, num_workers=10)
    train = iter(dataloader).next()[0]   # 一个batch的数据
    mean = np.mean(train.numpy(), axis=(0,2,3))
    std = np.std(train.numpy(), axis=(0,2,3))
    return mean, std

其思想主要是随机从数据集采样,直接调用numpy的方法返回数据集样本的均值和方差。

以下以CIFAR10数据集为例。

# cifar10
train_dataset = torchvision.datasets.CIFAR10('./data', 
                                             train=True, download=False, 
                                             transform=transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('./data', 
                                           train=False, download=False, 
                                            transform=transforms.ToTensor())

train_mean, train_std = get_mean_std(train_dataset)

test_mean, test_std = get_mean_std(test_dataset)

print(train_mean, train_std)
print(test_mean,test_std)

打印结果为:
在这里插入图片描述

  • 10
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值