在与图像相关的深度学习任务中，我们经常需要计算图像数据集（特别是私有图像数据集）三个通道的均值和标准差，用于对图像进行归一化（normalize）处理。下面以 CIFAR10 数据集为例，展示如何计算图像数据集三个通道的均值和标准差：
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    The statistics are taken over every value the loader yields, reducing
    over dimensions 0, 2 and 3 of each batch, so one value per index of
    dimension 1 is returned.

    Args:
        loader: Iterable yielding ``(X, _)`` pairs, where ``X`` is a 4-D
            tensor. The second item of each pair is ignored.

    Returns:
        A ``(mean, std)`` tuple of 1-D tensors with one entry per channel.
        ``std`` is the population standard deviation. Results use the dtype
        of the first batch when it is floating point, otherwise
        ``torch.float32``.

    Raises:
        ValueError: If the loader yields no values at all.
    """
    channel_sum = 0.0
    channel_square_sum = 0.0
    values_per_channel = 0
    out_dtype = None

    for X, _ in loader:
        if out_dtype is None:
            out_dtype = X.dtype if X.is_floating_point() else torch.float32
        # Accumulate raw sums in float64: summing (rather than averaging
        # per-batch means) weights every value equally even when batches
        # differ in size, e.g. a smaller final batch.
        X64 = X.to(torch.float64)
        channel_sum = channel_sum + X64.sum(dim=[0, 2, 3])
        channel_square_sum = channel_square_sum + (X64 ** 2).sum(dim=[0, 2, 3])
        values_per_channel += X.size(0) * X.size(2) * X.size(3)

    if values_per_channel == 0:
        raise ValueError("loader yielded no data; cannot compute mean/std")

    mean = channel_sum / values_per_channel
    mean_square = channel_square_sum / values_per_channel
    # VAR[X] = E[X**2] - E[X]**2. Rounding can push this slightly below
    # zero for (near-)constant channels, so clamp before the square root
    # to avoid NaN.
    variance = (mean_square - mean ** 2).clamp(min=0.0)
    std = variance.sqrt()
    return mean.to(out_dtype), std.to(out_dtype)
if __name__ == '__main__':
    # Example: per-channel statistics of the CIFAR10 training split.
    # ToTensor is the only transform applied, so the statistics are computed
    # on tensors that have not been normalized yet.
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    cifar_train = datasets.CIFAR10(
        root='/media/s5/gj/nodecode/CV/data',
        train=True,
        download=True,
        transform=to_tensor_only,
    )
    cifar_loader = DataLoader(cifar_train, batch_size=100, shuffle=True)

    channel_mean, channel_std = get_mean_std(cifar_loader)
    print("mean:", channel_mean)
    print("std:", channel_std)
    # Output noted in the original listing (mean first, then std):
    #   tensor([0.4914, 0.4822, 0.4465])
    #   tensor([0.2470, 0.2435, 0.2616])