# get_mean_and_std.py
import torch
from tqdm import tqdm
def get_mean_and_std(dataset, batch_size=1, num_workers=2, *, verbose=True):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Statistics are pooled over every pixel of every image. ``std`` is therefore
    the true dataset standard deviation, not the average of per-image standard
    deviations. That average underestimates the spread because it ignores how
    much the image means differ from one another.

    Args:
        dataset: Map-style dataset whose items unpack as ``(image, *rest)``.
            Batches are treated as ``(batch, channel, ...)``; images are
            presumably ``(C, H, W)`` tensors -- TODO confirm against callers'
            transforms.
        batch_size: Images per DataLoader batch. Values above 1 require all
            images to share one shape, because default collation stacks them.
        num_workers: Number of DataLoader worker processes.
        verbose: If True, print a start message and wrap the loader in a tqdm
            progress bar (the original behaviour).

    Returns:
        ``(mean, std)``: two 1-D tensors of length ``C`` in torch's default
        floating dtype.

    Raises:
        ValueError: If the dataset yields no pixels (e.g. it is empty).
    """
    # Iteration order does not affect the statistics, so don't shuffle.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    if verbose:
        print('===>Computing mean and std..')
        loader = tqdm(loader)

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for inputs, *_ in loader:
        # Accumulate in float64 on CPU. This keeps the large sums and the
        # E[x^2] - E[x]^2 step precise, and also accepts integer images.
        x = inputs.to(device='cpu', dtype=torch.float64)
        # Put channels first and flatten everything else: (C, N * pixels).
        x = x.transpose(0, 1).reshape(x.shape[1], -1)
        if channel_sum is None:
            # Channel count is taken from the data rather than assumed to be 3.
            channel_sum = torch.zeros(x.shape[0], dtype=torch.float64)
            channel_sq_sum = torch.zeros(x.shape[0], dtype=torch.float64)
        channel_sum += x.sum(dim=1)
        channel_sq_sum += (x * x).sum(dim=1)
        pixel_count += x.shape[1]

    if pixel_count == 0:
        raise ValueError('dataset yielded no pixels; cannot compute mean and std')

    mean = channel_sum / pixel_count
    # Rounding can make the variance a tiny negative number; clamp before sqrt.
    var = (channel_sq_sum / pixel_count - mean * mean).clamp_min(0)
    std = var.sqrt()
    out_dtype = torch.get_default_dtype()
    return mean.to(out_dtype), std.to(out_dtype)
# Usage:
# Compute per-channel normalization statistics for the training split.
# NOTE(review): `train_dataset` is not defined in the visible source; it is
# assumed to be created earlier in the file or session.
mean, std = get_mean_and_std(train_dataset)
print('mean:', mean, 'std:', std)