关于pytorch中transform.Normalize的理解

        之前对于使用pytorch中的transforms.Normalize进行图像的标准化原理一直存在困扰,今天通过学习以下文章明白了一些,做一个记录。

参考链接地址:(21条消息) pytorch中归一化transforms.Normalize的真正计算过程_月光下的小白菜的博客-CSDN博客_transform里的normalize


        在训练中使用图像标准化主要是为了在梯度下降中算法可以更好的进行寻优,对于这一原理可以参考吴恩达的以下课程。

课程链接:【中英字幕】吴恩达深度学习课程第二课 — 改善深层神经网络:超参数调试、正则化以及优化_哔哩哔哩_bilibili

        使用pytorch进行图像标准化的代码一般如下:

data=transforms.ToTensor()(img)
data=transforms.Normalize(mean,std)(data)

         其中第一行是将读取到的图片的维度进行转换(W,H,C转换为C,W,H)并将每个像素值除以255,第二行代码是通过如下公式进行标准化。

x=(x-mean)/std

        其中mean为均值std为标准差。

        下边参考第一条链接文章中的代码详细的分析:

        首先自己根据公式进行标准化计算:

import torch
import numpy as np
from torchvision import transforms
#模拟读取到的一张图片,数据类型一定要是uin8否则不会进行/255的归一化
#'uint8'为最大值为255的专门用来储存图片信息的数据类型
data=np.array([
    [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],],
    [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2],],
    [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],],
    [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4],],
    [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5],],
],dtype='uint8')
#将w,h,c转化为c,w,h
data_to=transforms.ToTensor()(data)
#将data加上batch的维度,在实际中应对整个数据集进行统计
data=torch.unsqueeze(data_to,0)
N,C,H,W=data.shape[:4]
#根据三通道展平数据
data=data.view(N,C,-1)
#计算均值以及标准差
channel_mean=torch.zeros(3)
channel_std=torch.zeros(3)
channel_mean += data.mean(2).sum(0)
channel_std +=  data.std(2).sum(0)
print(channel_mean)
print(channel_std)

        打印均值及标准差:

tensor([[0.0118, 0.0118, 0.0118]])
tensor([[0.0057, 0.0057, 0.0057]])

        根据公式将输入数据进行标准化处理: 

for i in range(3):
    data_to[i,:,:]=(data_to[i,:,:]-channel_mean[i])/channel_std[i]
print(data_to)

        输出结果: 

tensor([[[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]],

        [[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]],

        [[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]]])

        使用pytorch提供的方法进行计算: 

data=np.array([
    [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],],
    [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2],],
    [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],],
    [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4],],
    [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5],],
],dtype='uint8')
data=transforms.ToTensor()(data)
data=transforms.Normalize(channel_mean,channel_std)(data)
print(data)

         输出结果: 

tensor([[[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]],

        [[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]],

        [[-1.3856, -1.3856, -1.3856, -1.3856, -1.3856],
         [-0.6928, -0.6928, -0.6928, -0.6928, -0.6928],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.6928,  0.6928,  0.6928,  0.6928,  0.6928],
         [ 1.3856,  1.3856,  1.3856,  1.3856,  1.3856]]])

         对比发现两次计算的结果相同。

 

  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值