import os
import numpy as np
import torchvision
import torchvision.datasets as datasets
import json
import torchvision.transforms as transforms
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import sys
import cv2
class MyDataset(Dataset):
    """Image dataset driven by a whitespace-separated listing file.

    Each line of ``txt_path`` is ``<relative_image_path> <integer_label>``.
    Images are loaded with PIL, forced to RGB, and optionally transformed.
    """

    def __init__(self,
                 img_path,
                 txt_path,
                 img_transform=None):
        """
        :param img_path: directory prefix joined onto every listed image path
        :param txt_path: listing file with one "path label" pair per line
        :param img_transform: optional callable applied to the PIL image
        """
        with open(txt_path, 'r') as f:
            lines = f.readlines()
        self.img_list = [
            os.path.join(img_path, i.split()[0]) for i in lines
        ]
        # Labels are kept as the raw strings from the listing file and
        # converted to int lazily in __getitem__ (preserves original layout).
        self.label_list = [i.split()[1] for i in lines]
        self.img_transform = img_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` where label is a 0-dim long tensor."""
        img_path = self.img_list[index]
        label = self.label_list[index]
        # convert('RGB') guarantees a 3-channel image, so no mode check is
        # needed afterwards.  (The original post-convert RGB check was
        # unreachable and referenced a nonexistent attribute.)
        img = Image.open(img_path).convert('RGB')
        if self.img_transform is not None:
            img = self.img_transform(img)
        # Equivalent to torch.from_numpy(np.array(label).astype(int)).long()
        # without the numpy round-trip.
        return img, torch.tensor(int(label), dtype=torch.long)

    def __len__(self):
        return len(self.label_list)

    @staticmethod
    def collate_fn(batch):
        """Stack a list of (image, label) pairs into batched tensors."""
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
def data_transforms_():
    """Return the train/val torchvision transform pipelines.

    :return: dict with keys "train" (random crop + flip augmentation) and
             "val" (deterministic resize + center crop), both normalized
             with their respective per-channel mean/std.
    """
    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.61, 0.56, 0.52], [0.20, 0.21, 0.22]),
    ])
    val_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.62, 0.57, 0.53], [0.21, 0.22, 0.23]),
    ])
    return {"train": train_pipeline, "val": val_pipeline}
import torch
from torchvision.datasets import ImageFolder
from utils.MyDataset import MyDataset, data_transforms_
def getStat(train_data):
    """Compute per-channel mean and std for training data.

    Iterates the dataset one sample at a time and averages each image's
    per-channel mean and std.  NOTE: averaging per-image stds is an
    approximation of the true dataset std (it ignores between-image
    variance), but it matches the original behavior of this function.

    :param train_data: a Dataset (or ImageFolder) yielding (C=3 image, label)
    :return: (mean, std) as two lists of 3 Python floats
    :raises ValueError: if the dataset is empty (would otherwise divide by 0)
    """
    print('Compute mean and variance for training data.')
    num_samples = len(train_data)
    print(num_samples)
    if num_samples == 0:
        raise ValueError('train_data is empty; cannot compute statistics.')
    # pin_memory dropped: the result never moves to GPU, and pinning
    # fails/warns on CPU-only builds.
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        # Vectorized over channels: reduces over batch/height/width, which is
        # exactly the element set X[:, d, :, :] covered per channel d.
        mean += X.mean(dim=(0, 2, 3))
        std += X.std(dim=(0, 2, 3))
    mean.div_(num_samples)
    std.div_(num_samples)
    return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
    # Compute channel statistics over the validation listing using the
    # deterministic "val" pipeline (no random augmentation).
    val_transform = data_transforms_()['val']
    dataset = MyDataset(img_path='data_set/Mydata/',
                        txt_path='data_set/Mydata/val.txt',
                        img_transform=val_transform)
    print(getStat(dataset))