import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
import cv2
import numpy as np
import random
# Fix the seed of Python's `random` module at import time so that the
# random.randint draws made by CamObjDataset.getFlip (the horizontal-flip
# decision) follow the same sequence for a given call order.
random.seed(2021)
class CamObjDataset(data.Dataset):
    """Dataset yielding (image, ground-truth mask, edge map) training triples.

    Images are the ``.jpg`` files in ``image_root``; ground-truth masks and
    edge maps are the ``.png`` files in ``gt_root`` and ``edge_root``. The
    three file lists are sorted and paired by position, so the directories
    are expected to use matching file names.

    Args:
        image_root: Directory containing the RGB ``.jpg`` images.
        gt_root: Directory containing the ``.png`` ground-truth masks.
        edge_root: Directory containing the ``.png`` edge maps.
        trainsize: Side length; every sample is resized to
            ``(trainsize, trainsize)``.

    Raises:
        AssertionError: If the three directories do not contain the same
            number of matching files.
    """

    def __init__(self, image_root, gt_root, edge_root, trainsize):
        super().__init__()
        self.trainsize = trainsize  # side length used for training samples

        # Build the sorted path lists; sorting keeps image / mask / edge
        # entries aligned by position.
        self.images = self._list_files(image_root, '.jpg')
        self.gts = self._list_files(gt_root, '.png')
        self.edges = self._list_files(edge_root, '.png')

        # Drop triples whose image, mask and edge sizes disagree.
        self.filter_files()
        self.size = len(self.images)

        # 5x5 structuring element used to dilate (thicken) the edge maps.
        self.kernel = np.ones((5, 5), np.uint8)

        # Image pipeline: resize, convert to tensor, normalise per channel
        # (the constants are the mean/std commonly used with
        # ImageNet-pretrained backbones).
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])])
        # Mask / edge pipeline: resize and convert to tensor only.
        self.ge_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    @staticmethod
    def _list_files(root, suffix):
        """Return the sorted paths of the files in ``root`` ending in ``suffix``."""
        # os.path.join works whether or not ``root`` ends with a separator;
        # plain string concatenation silently produced wrong paths otherwise.
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(suffix)
        )

    def getFlip(self):
        """Decide at random whether the next sample is flipped horizontally.

        The probability handed to ``RandomHorizontalFlip`` is either 0 or 1,
        so the resulting ``self.flip`` transform is deterministic and applies
        the same decision to the image, the mask and the edge map.
        """
        p = random.randint(0, 1)
        self.flip = transforms.RandomHorizontalFlip(p)

    def __getitem__(self, index):
        """Load, augment and transform the sample at ``index``.

        Returns:
            Tuple ``(image, gt, edge)`` of tensors produced by the image and
            mask/edge transform pipelines.

        Raises:
            OSError: If the edge map cannot be read.
        """
        self.getFlip()

        # Load the image, the mask and the edge map.
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        edge_path = self.edges[index]
        edge = cv2.imread(edge_path, cv2.IMREAD_GRAYSCALE)
        if edge is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read edge map: {edge_path}")

        image = self.img_transform(self.flip(image))
        gt = self.ge_transform(self.flip(gt))

        # Thicken the edges before converting to a PIL image for the
        # shared flip / resize / to-tensor pipeline.
        edge = cv2.dilate(edge, self.kernel, iterations=1)
        edge = Image.fromarray(edge)
        edge = self.ge_transform(self.flip(edge))
        return image, gt, edge

    def filter_files(self):
        """Keep only the triples whose image, mask and edge sizes all match.

        Raises:
            AssertionError: If the three path lists differ in length.
        """
        # Raised explicitly (rather than via ``assert``) so the checks survive
        # ``python -O`` while keeping the original exception type.
        if len(self.images) != len(self.gts):
            raise AssertionError(
                f"Found {len(self.images)} images but {len(self.gts)} "
                "ground-truth masks")
        if len(self.images) != len(self.edges):
            # zip() would silently truncate and could misalign the pairs.
            raise AssertionError(
                f"Found {len(self.images)} images but {len(self.edges)} "
                "edge maps")

        images = []
        gts = []
        edges = []
        for img_path, gt_path, edge_path in zip(self.images, self.gts, self.edges):
            # Context managers close the underlying files right away instead
            # of leaking one handle per inspected image.
            with Image.open(img_path) as img, \
                    Image.open(gt_path) as gt, \
                    Image.open(edge_path) as edge:
                same_size = img.size == gt.size and img.size == edge.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
                edges.append(edge_path)
        self.images = images
        self.gts = gts
        self.edges = edges

    def rgb_loader(self, path):
        """Load the file at ``path`` as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load the file at ``path`` as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale ``img`` and ``gt`` if either side is below ``trainsize``.

        The image is resampled bilinearly and the mask with nearest-neighbour
        so that label values are not blended. Pairs that are already large
        enough are returned unchanged.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        return img, gt

    def __len__(self):
        """Return the number of samples kept after filtering."""
        return self.size
Python数据增强——对数据预处理(随机裁剪、翻转)
最新推荐文章于 2024-07-12 19:06:42 发布