减去RGB均值(实例以DIV2K数据集为例)
在计算机视觉领域中,一定免不了的就是图像预处理中的 逐个样本减去mean值的过程,那么为什么一定要做这一步呢?
为什么每张图片都要减去数据集均值呢?
原因:为了进行数据特征标准化,即像机器学习中的特征预处理那样对输入特征向量各维去均值再除以标准差,但由于自然图像各点像素值的范围都在0-255之间,方差大致一样,只要做去均值(减去整个图像数据集的均值或各通道关于图像数据集的均值)处理即可。
主要原理:我们默认自然图像是一类平稳的数据分布(即数据每一维的统计都服从相同分布),此时,在每个样本上减去数据的统计平均值可以移除共同的部分,凸显个体差异。
其效果如下所示:
可以看到天空的纹理被移除了,凸显了汽车和高楼等主要特征。
最值得注意的一点是,在计算均值之前就要预先划分好训练集验证集和测试集,然后只针对训练集计算均值,否则就违背了深度学习的原则:模型训练过程仅能从训练模型中获取信息。得到训练集的均值后,对训练集验证集和测试集分别减去该均值。
以RCAN中代码为例:
class MeanShift(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1) # #第一维为输出通道,第二维为输入通道
self.weight.data.div_(std.view(3, 1, 1, 1)) # torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std) # sign * rgb_range * torch.Tensor(rgb_mean) / std
# (w*x+b)/std
self.requires_grad = False
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
sub_mean = MeanShift(255, rgb_mean, rgb_std)
add_mean = MeanShift(255, rgb_mean, rgb_std, sign=1)
print(sub_mean)
print('sub_mean的所有属性方法', dir(sub_mean)) # dir(Object)可以获得对象的所有属性和方法
print('stride', getattr(sub_mean,'stride')) # getattr(object, name)获取对象object的属性或者方法
print('判断是否有weight属性', hasattr(sub_mean, 'weight')) # hasattr(Object, "name")可以判断对象是否有某个属性
x = imageio.imread('322878.jpg') # 载入图片
np_transpose = np.ascontiguousarray(x.transpose((2, 0, 1)))
tensor = torch.from_numpy(np_transpose).float()
tensor = tensor.unsqueeze(0)
tensor.mul_(255. / 255)
print('input',tensor, tensor.shape)
tensor = sub_mean(tensor)
print('output',tensor, tensor.shape)
tensor = add_mean(tensor)
print('reoutput', tensor, tensor.shape)
输出
MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
sub_mean的所有属性方法 ['__call__', '__class__', '__constants__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backend', '_backward_hooks', '_buffers', '_forward_hooks&