import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
import cv2
import numpy as np
import random
random.seed(2021)
class CamObjDataset(data.Dataset):
def __init__(self, image_root, gt_root, edge_root, trainsize):
self.trainsize = trainsize # 用于训练时图像的大小
# 图像 标签 边缘图像列表创建和排序
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')]
self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.png')]
# 对这些列表进行排序,以确保图像、标签和边缘图像的顺序一致
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.edges = sorted(self.edges)
# 过滤不匹配文件:移除尺寸不一致的图像、标签和边缘图像
self.filter_files()
self.size = len(self.images)