import numpy as np
import torch
import torch.utils.data as data
import os
from PIL import Image
import random
from torchvision import transforms
def is_image_file(filename):
    """Return True if *filename* ends with a supported image extension.

    Matching is case-sensitive (``.PNG`` is not accepted), as in the
    original implementation.
    """
    # str.endswith accepts a tuple of suffixes: one C-level call instead
    # of any() over a generator expression.
    return filename.endswith((".png", ".jpg", ".jpeg", ".tif"))
def data_augment(img1, img2, flip=1, ROTATE_90=1, ROTATE_180=1, ROTATE_270=1, add_noise=1):
    """Apply one randomly chosen augmentation to a paired (image, label).

    The same transform is applied to both *img1* and *img2* so the input
    image and its label stay spatially aligned.  Each flag (1 enables,
    0 disables) adds the corresponding PIL transpose to the candidate
    pool; one candidate is then drawn uniformly at random.

    BUG FIXED: the original body drew a random number but never used it,
    applied *every* enabled transform in sequence, and fell off the end
    without a return, so callers always received ``None``.

    Returns:
        (img1, img2): the (possibly transformed) pair.
    """
    candidates = []
    if flip == 1:
        candidates.append(Image.FLIP_LEFT_RIGHT)
    if ROTATE_90 == 1:
        candidates.append(Image.ROTATE_90)
    if ROTATE_180 == 1:
        candidates.append(Image.ROTATE_180)
    if ROTATE_270 == 1:
        candidates.append(Image.ROTATE_270)
    # add_noise is accepted for interface compatibility but, as in the
    # original, no noise operation is implemented.
    if candidates:
        op = random.choice(candidates)
        img1 = img1.transpose(op)
        img2 = img2.transpose(op)
    return img1, img2
class train_dataset(data.Dataset):
    """Paired (image, label) dataset for training.

    Expects ``<data_path>/src/`` to hold the input images and
    ``<data_path>/label/`` to hold label images with identical
    filenames.  Each item is ``(image_tensor, label_tensor, filename)``.
    """

    def __init__(self, data_path='', size_w=256, size_h=256, flip=1):
        super(train_dataset, self).__init__()
        # Index every image file found under <data_path>/src/.
        self.list = [x for x in os.listdir(os.path.join(data_path, 'src'))
                     if is_image_file(x)]
        self.data_path = data_path
        # size_w/size_h are retained for interface compatibility; the
        # resize step that used them is currently disabled.
        self.size_w = size_w
        self.size_h = size_h
        self.flip = flip
        # Input images: to tensor, normalized from [0, 1] to [-1, 1].
        self.transform1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        # Label images: to tensor in [0, 1] only (no normalization).
        self.transform2 = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor, filename)`` for *index*.

        Returns ``(None, None, None)`` when either file is unreadable,
        preserving the original best-effort contract.

        Raises:
            FileNotFoundError: if the label file does not exist.
        """
        name = self.list[index]
        initial_path = os.path.join(self.data_path, 'src', name)
        semantic_path = os.path.join(self.data_path, 'label', name)
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not os.path.exists(semantic_path):
            raise FileNotFoundError(semantic_path)
        try:
            initial_image = Image.open(initial_path).convert('RGB')
            semantic_image = Image.open(semantic_path).convert('L')
        except OSError:
            return None, None, None
        if self.flip == 1:
            initial_image, semantic_image = self._random_flip(
                initial_image, semantic_image)
        return self.transform1(initial_image), self.transform2(semantic_image), name

    @staticmethod
    def _random_flip(initial_image, semantic_image):
        """Apply one randomly selected flip/rotation to both images.

        With probability 1/2 (a >= 4/8) the pair is returned unchanged;
        otherwise the same PIL transpose is applied to both images so
        they stay aligned.
        """
        a = random.random()
        if a < 1 / 16:
            op = Image.FLIP_LEFT_RIGHT
        elif a < 1 / 8:
            op = Image.FLIP_TOP_BOTTOM
        elif a < 2 / 8:
            op = Image.ROTATE_90
        elif a < 3 / 8:
            op = Image.ROTATE_180
        elif a < 4 / 8:
            op = Image.ROTATE_270
        else:
            return initial_image, semantic_image
        return initial_image.transpose(op), semantic_image.transpose(op)

    def __len__(self):
        """Number of indexed source images."""
        return len(self.list)
class val_dataset(data.Dataset):
    """Paired (image, label) dataset for validation.

    Expects ``<data_path>/val/`` to hold the input images and
    ``<data_path>/vallabel/`` to hold label images with identical
    filenames.  Each item is ``(image_tensor, label_tensor, filename)``.
    Augmentation is off by default (``flip=0``).
    """

    def __init__(self, data_path='', size_w=256, size_h=256, flip=0):
        super(val_dataset, self).__init__()
        # Index every image file found under <data_path>/val/.
        # size_w/size_h are accepted for interface compatibility but,
        # as in the original, are not stored (resize is disabled).
        self.list = [x for x in os.listdir(os.path.join(data_path, 'val'))
                     if is_image_file(x)]
        self.data_path = data_path
        self.flip = flip
        # Input images: to tensor, normalized from [0, 1] to [-1, 1].
        self.transform1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        # Label images: to tensor in [0, 1] only (no normalization).
        self.transform2 = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor, filename)`` for *index*.

        Returns ``(None, None, None)`` when either file is unreadable,
        preserving the original best-effort contract.  Note: unlike the
        training dataset, no existence check is performed on the label
        path (the original had it commented out).
        """
        name = self.list[index]
        initial_path = os.path.join(self.data_path, 'val', name)
        semantic_path = os.path.join(self.data_path, 'vallabel', name)
        try:
            initial_image = Image.open(initial_path).convert('RGB')
            semantic_image = Image.open(semantic_path).convert('L')
        except OSError:
            return None, None, None
        if self.flip == 1:
            initial_image, semantic_image = self._random_flip(
                initial_image, semantic_image)
        return self.transform1(initial_image), self.transform2(semantic_image), name

    @staticmethod
    def _random_flip(initial_image, semantic_image):
        """Apply one randomly selected flip/rotation to both images.

        With probability 1/2 (a >= 4/8) the pair is returned unchanged;
        otherwise the same PIL transpose is applied to both images so
        they stay aligned.
        """
        a = random.random()
        if a < 1 / 16:
            op = Image.FLIP_LEFT_RIGHT
        elif a < 1 / 8:
            op = Image.FLIP_TOP_BOTTOM
        elif a < 2 / 8:
            op = Image.ROTATE_90
        elif a < 3 / 8:
            op = Image.ROTATE_180
        elif a < 4 / 8:
            op = Image.ROTATE_270
        else:
            return initial_image, semantic_image
        return initial_image.transpose(op), semantic_image.transpose(op)

    def __len__(self):
        """Number of indexed validation images."""
        return len(self.list)
# Data loading module (数据加载).
# NOTE(review): the two trailing lines here were web-scrape residue (a blog
# title and publish timestamp) left as bare text, which is a Python syntax
# error at module level; converted to this comment.