1.数据标准化
Normalize()函数的作用是将数据转换为标准高斯分布,即逐个 c h a n n e l channel channel的对图像进行标准化(均值变为 0 0 0,标准差为 1 1 1),可以加快模型的收敛,具体的采
- m e a n mean mean:各通道的均值
- s t d std std:各通道的标准差
- i n p l a c e inplace inplace:是否原地操作
- o u t p u t [ c h a n n e l ] = i n p u t [ c h a n n e l ] − m e a n [ c h a n n e l ] s t d [ c h a n n e l ] output[channel]=\frac{input[channel]-mean[channel]}{std[channel]} output[channel]=std[channel]input[channel]−mean[channel]
2.代码实例
经常看到的 m e a n = [ 0.485 , 0.456 , 0.406 ] mean=[0.485, 0.456, 0.406] mean=[0.485,0.456,0.406], s t d = [ 0.229 , 0.224 , 0.225 ] std=[0.229, 0.224, 0.225] std=[0.229,0.224,0.225]表示的是从数据集中随机抽样计算得到的。
from torchvision import models, transforms
# 迁移学习,预训练模型
net = models.resnet18(pretrained=True)
# 数据转换
image_transform = transforms.Compose([
# 将输入图片resize成统一尺寸
transforms.Resize([224, 224]),
# 将PIL Image或numpy.ndarray转换为tensor,并除255归一化到[0,1]之间
transforms.ToTensor(),
# 标准化处理-->转换为标准正太分布,使模型更容易收敛
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])