求标准化所需的mean和std的三种方法
什么是标准化?
这里所说的标准化是针对pytorch中的 transforms.Normalize(mean=..., std=...)
函数而言,即 z-score标准化 :
x
∗
=
x
−
x
ˉ
σ
x^* = \frac{x - \bar{x}}{\sigma}
x∗=σx−xˉ
下面给出方法代码,注意,由于我使用的数据集是灰度图片,所以计算一个通道就可以了 。如果使用RGB图片,只需额外在循环语句里加几行代码以计算另外两个通道的相关参数就行了。
方法一
求整个数据集每张图片的均值、标准差之和,再除以数据集里图片个数作为整个数据集的标准化所需的参数: 其基本思想如下图:
代码如下:
import os
import numpy as np
from PIL import Image
if __name__ == '__main__':
filepath = r"C:/Users/xxx/images/" # 数据集目录
pathDir = os.listdir(filepath) # 数据集目录下图片
num = len(pathDir) # 这里(512,512)是每幅图片的尺寸
print("Computing mean...")
data_mean = 0.
for idx in range(len(pathDir)):
filename = pathDir[idx]
img = Image.open(os.path.join(filepath, filename)).convert('L')
img = np.array(img) / 255.0
data_mean += np.mean(img) # 取三维矩阵中第一维的所有数据
# 由于使用的是灰度图片,所以计算一个通道就可以了
data_mean = data_mean / num
print("Computing var...")
data_std = 0.
for idx in range(len(pathDir)):
filename = pathDir[idx]
img = Image.open(os.path.join(filepath, filename)).convert('L')
img = np.array(img) / 255.0
data_std += np.std(img)
data_std = data_std / num
print("mean:{}".format(data_mean))
print("std:{}".format(data_std))
方法二
求整个数据集所有图片的均值和标准差: 基本思想就是将数据集中所有图片给拼接成一个大图片,求这个大图片的均值和标准差作为整个数据集的标准化所需的参数。代码如下:
import os
import numpy as np
from PIL import Image
if __name__ == '__main__':
filepath = r"C:/Users/xxx/images/" # 数据集目录
pathDir = os.listdir(filepath) # 数据集目录下图片
num = len(pathDir) * 512 * 512 # 数据集目录下图片尺寸
print("Computing mean...")
data_mean = 0.
for idx in range(len(pathDir)):
filename = pathDir[idx]
img = Image.open(os.path.join(filepath, filename)).convert('L')
img = np.array(img)
img = img.astype(np.float64) / 255.0
data_mean += np.sum(img)
# 由于使用的是灰度图片,所以计算一个通道就可以了
data_mean = data_mean / num
print("Computing std...")
data_std = 0.
for idx in range(len(pathDir)):
filename = pathDir[idx]
img = Image.open(os.path.join(filepath, filename)).convert('L')
img = np.array(img)
img = img.astype(np.float64) / 255.0
data_std += np.sum((img - data_mean) ** 2)
data_std = np.sqrt(data_std / num)
print("mean:{}".format(data_mean))
print("std:{}".format(data_std))
方法三
计算单张图片的均值标准差,每张图片各自标准化。就是说每张图片标准化时使用的mean和std都是根据当前读入的图片,实时计算出来的mean和std,从而对该张图片进行标准化。 代码如下:
def data_transform(image_object, label_object):
tensor_transform = transforms.ToTensor()
image_object = tensor_transform(image_object)
# 由于使用的是灰度图片,所以计算一个通道就可以了
image_mean = image_object[0, :, :].mean()
image_std = image_object[0, :, :].std()
normalize_transform = transforms.Normalize(mean=image_mean, std=image_std)
image_object = normalize_transform(image_object)
if label_object is not None:
label_object = tensor_transform(label_object)
return image_object, label_object
return image_object
说明
- 方法三给的是一个函数,只需要在构建数据集的时候调用这个函数就行了。方法一、二的、则可以视为单独对数据集的一个分析脚本,计算出来的mean和std直接在数据集构建的时候放入
transforms.Normalize(mean=..., std=...)
参数即可。 - 根据概率论与数理统计的知识,方法一和方法二计算出来的均值是相同的,但是标准差并不相同 ,严格来说,方法二计算出来的标准差才是整个数据集真正的标准差,方法一计算出来的标准差只是在数据集的图片分布满足一定条件时,对真正的标准差的一个近似估计。(表达可能有误,我对概率论那边也不是很熟,如果说错了,请提醒我纠正,谢谢)
- 方法三提供的方法是我自己想出来的,因为在使用方法一、二计算出来的均值标准差进行标准化的时候,有的图片中的很多像素值都会超过[0, 1]这个范围,个人感觉可能不太好,于是写出来了这个函数,对每张图片各自进行标准化。而在实际测试时,发现经此函数归一化后的也是可以顺利训练出具有良好性能的模型的。放在这里并不代表其完全符合数据集标准化的思想,只是作为一种可行的方法。