计算pytorch标准化(Normalize)所需要数据集的均值和方差

先说明一下情况

1、如果是自己的数据集,mean 和 std 肯定要在normalize之前自己先算好再传进去的

2、有两种情况:
a)数据集在加载的时候就已经转换成了[0, 1].
b)应用了torchvision.transforms.ToTensor,其作用是
( Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] )

3、[0.485, 0.456, 0.406]这一组平均值一般是抽样算出来的。

如何计算数据集的mean和std

数据存放路径:

在这里插入图片描述

calculate_mead_and_std.py

import random
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os

random.seed(1)
# dict_label:类别对应表
# dict_label = {"airplane": 0, "automobile": 1, "bird": 2, "cat": 3, "deer": 4,"dog": 5,
#               "frog": 6, "horse": 7, "ship": 8, "truck": 9}
dict_label = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4,"5": 5,
              "6": 6, "7": 7, "8": 8, "9": 9}  # 如果改了分类目标,这里需要修改



def get_img_info(data_dir):
    data_info = list()
    for root, dirs, _ in os.walk(data_dir):
        # 遍历类别
        for sub_dir in dirs:
            img_names = os.listdir(os.path.join(root, sub_dir))
            img_names = list(filter(lambda x: x.endswith('.png'), img_names)) # 过滤,剩下.png结尾的文件名
            # 遍历图片
            for i in range(len(img_names)):
                img_name = img_names[i]
                path_img = os.path.join(root, sub_dir, img_name) # 完整图片路径
                label = dict_label[sub_dir] # 获取当前图片的标签
                data_info.append((path_img, int(label))) # 返回 [(path_img1,label1),(path_img2,label2),...]

    return data_info

class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = dict_label
        self.data_info = get_img_info(data_dir)  # data_info存储所有图片路径和标签
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等
        return img, label

    def __len__(self):
        return len(self.data_info)


#  指定计算mean和std的图像数据集路径
train_dir = os.path.join('.', 'split_data',"train")

#  图像预处理
train_transform = transforms.Compose([
    transforms.Resize((32, 32)), # 可以改成你图片近似大小或者模型要求大小
    transforms.ToTensor(),
])

train_data = MyDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=3000, shuffle=True) # 3000张图片的mean std
train = iter(train_loader).next()[0]  # 3000张图片的mean、std
train_mean = np.mean(train.numpy(), axis=(0, 2, 3))
train_std = np.std(train.numpy(), axis=(0, 2, 3))

print("train_mean:",train_mean)
print("train_std:",train_std)



在这里插入图片描述

## [numpy中的mean和std中axis的使用](https://blog.csdn.net/qq_41375318/article/details/115439612)
  • 4
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值