I'm sure we have all seen a snippet like this:
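The mean and std values passed to transforms.Normalize here are the per-channel mean and standard deviation of ImageNet. However, I have also seen code that does not use these values, which made me wonder how to compute the mean and standard deviation of the images in my own dataset. The code for the calculation is given below; the implementation comes from this book:
Here I use a small Ultraman dataset as an example and compute its mean and standard deviation:
Implementation: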
import os
import torch
import imageio
import numpy as np
batch_size = 138  # number of images in your dataset
h, w = 256, 256  # Normalize in our transforms usually follows a resize/crop, so bring every image to this size before computing
batch = torch.zeros(batch_size, 3, h, w, dtype=torch.uint8)
data_dir = r'D:\Dataset\AoteMan\train\奥特曼'
print('Number of images:', len(os.listdir(data_dir)))
filenames = [name for name in os.listdir(data_dir)]
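for i, filename in enumerate(filenames):
    img_arr = imageio.imread(os.path.join(data_dir, filename))
    img_t = torch.from_numpy(img_arr)
    img_t = img_t.permute(2, 0, 1)  # h,w,c -> c,h,w
    img_t = img_t[:3]  # keep the first three channels (drops alpha if present)
    # np.resize only repeats or truncates the raw data and does not rescale the image,
    # so resize with bilinear interpolation instead
    img_t = torch.nn.functional.interpolate(img_t[None].float(), size=(h, w), mode='bilinear', align_corners=False)[0]
    batch[i] = img_t.round().to(torch.uint8)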
batch = batch.float()
batch /= 255.
n_channels = batch.shape[1]
means, stds = [], []
for c in range(n_channels):
    mean = torch.mean(batch[:, c])
    means.append(mean)
    std = torch.std(batch[:, c])
    stds.append(std)
    batch[:, c] = (batch[:, c] - mean) / std
print(means)
print(stds)
Result:
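The script above holds every image in a single tensor, which is fine for 138 images but may not fit in memory for a larger dataset. One possible alternative, offered as a minimal sketch and not as the book's method, is to accumulate per-channel sums and sums of squares one image at a time. The function name `channel_stats` and the random stand-in images are hypothetical.

```python
import torch

def channel_stats(images):
    """Per-channel mean and std over an iterable of (C, H, W) float tensors scaled to [0, 1]."""
    total = None     # running per-channel sum of pixel values
    total_sq = None  # running per-channel sum of squared pixel values
    count = 0        # number of pixels seen per channel
    for img in images:
        img = img.double()  # accumulate in float64 to limit rounding error
        if total is None:
            total = torch.zeros(img.shape[0], dtype=torch.float64)
            total_sq = torch.zeros(img.shape[0], dtype=torch.float64)
        total += img.sum(dim=(1, 2))
        total_sq += (img ** 2).sum(dim=(1, 2))
        count += img.shape[1] * img.shape[2]
    mean = total / count
    # Population variance E[x^2] - E[x]^2. torch.std defaults to the unbiased
    # estimator, so values can differ very slightly from the script above.
    var = total_sq / count - mean ** 2
    std = var.clamp(min=0).sqrt()
    return mean, std

# Synthetic stand-in for a dataset: 10 random 3x8x8 "images".
torch.manual_seed(0)
fake_images = [torch.rand(3, 8, 8) for _ in range(10)]
mean, std = channel_stats(fake_images)
print(mean, std)
```

Because only running totals are kept, the images do not need to share a size, and the resulting `mean` and `std` can be passed to transforms.Normalize as lists via `mean.tolist()` and `std.tolist()`.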