我们在使用模型训练之前一般要对数据进行归一化(Normalize),归一化之前需要得到数据集整体的方差和均值,这里提供了一个简单计算数据标准差和均值的接口,方便大家使用。
def get_mean_std(dataset, ratio=0.01):
"""Get mean and std by sample ratio
"""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=int(len(dataset)*ratio),
shuffle=True, num_workers=10)
train = iter(dataloader).next()[0] # 一个batch的数据
mean = np.mean(train.numpy(), axis=(0,2,3))
std = np.std(train.numpy(), axis=(0,2,3))
return mean, std
其思想主要是随机从数据集采样,直接调用numpy的方法返回数据集样本的均值和方差。
以下以CIFAR10数据集为例。
# cifar10
train_dataset = torchvision.datasets.CIFAR10('./data',
train=True, download=False,
transform=transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('./data',
train=False, download=False,
transform=transforms.ToTensor())
train_mean, train_std = get_mean_std(train_dataset)
test_mean, test_std = get_mean_std(test_dataset)
print(train_mean, train_std)
print(test_mean,test_std)
打印结果为: