深度学习数据集均值求和

# @Author:Fangwenxuan
import os

import numpy as np
import torchvision
import torchvision.datasets as datasets
import json
import torchvision.transforms as transforms
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import sys
import cv2


class MyDataset(Dataset):
    def __init__(self,
                 img_path,
                 txt_path,
                 img_transform=None):
        with open(txt_path, 'r') as f:
            lines = f.readlines()
            self.img_list = [
                os.path.join(img_path, i.split()[0]) for i in lines
            ]
            self.label_list = [i.split()[1] for i in lines]
        self.img_transform = img_transform

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label = self.label_list[index]
        img = Image.open(img_path).convert('RGB')
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.image_list[index]))
        # img = self.loader(img_path)
        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, torch.from_numpy(np.array(label).astype(int)).long()

    def __len__(self):
        return len(self.label_list)

    @staticmethod
    def collate_fn(batch):
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


def data_transforms_():
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.61, 0.56, 0.52], [0.20, 0.21, 0.22])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(256),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.62, 0.57, 0.53], [0.21, 0.22, 0.23])])}
    return data_transform

# @Author:Fangwenxuan
import torch
from torchvision.datasets import ImageFolder

from utils.MyDataset import MyDataset, data_transforms_


def getStat(train_data):
    '''
    Compute mean and variance for training data
    :param train_data: 自定义类Dataset(或ImageFolder即可)
    :return: (mean, std)
    '''
    print('Compute mean and variance for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


# [0.6175137, 0.56862676, 0.52036345], [0.20779464, 0.21386808, 0.22357221]
# ([0.62248206, 0.5772133, 0.5316572], [0.21813951, 0.22444752, 0.23638827])
if __name__ == '__main__':
    train_dataset = MyDataset(img_path='data_set/Mydata/', txt_path='data_set/Mydata/val.txt',
                              img_transform=data_transforms_()['val'])
    print(getStat(train_dataset))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值