import glob
import logging
import os
from os import listdir
from os.path import splitext
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class BasicDataset(Dataset):
    """Segmentation dataset that pairs each file in ``images_dir`` with a mask.

    Each sample ID is an image file name with its last extension removed
    (dot-prefixed entries are skipped). The matching mask is the file in
    ``masks_dir`` named ``<id><mask_suffix>.<ext>``.

    Args:
        images_dir: Directory holding the input images.
        masks_dir: Directory holding the masks.
        scale: Resize factor applied to both image and mask, in (0, 1].
        mask_suffix: Text appended to the ID to form the mask file's base name.
    """

    def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        self.mask_suffix = mask_suffix
        # Sample IDs: file names minus their last extension, skipping dot-prefixed entries.
        self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
        logging.info('Creating dataset with %d examples', len(self.ids))

    def __len__(self):
        """Return the number of sample IDs."""
        return len(self.ids)

    @staticmethod
    def _find_files(directory: Path, base: str, exact_stem: bool = False) -> list:
        """List files in *directory* whose name is *base* followed by ``.<ext>``.

        *base* is passed through ``glob.escape`` so that ``[``, ``*`` and ``?``
        occurring in real file names are matched literally, not as wildcards.
        With ``exact_stem=True`` only paths whose ``Path.stem`` equals *base* are
        kept, so base ``'a'`` matches ``a.png`` but not ``a.b.png``.
        """
        matches = directory.glob(glob.escape(base) + '.*')
        if exact_stem:
            return [path for path in matches if path.stem == base]
        return list(matches)

    @classmethod
    def preprocess(cls, pil_img, scale, is_mask):
        """Resize a PIL image by *scale* and convert it to a NumPy array.

        Masks are resized with nearest-neighbour sampling (so label values are not
        blended) and returned as NumPy produces them. Images are resized bicubically,
        moved to channel-first layout (2-D images gain a leading channel axis) and
        divided by 255.

        Raises:
            AssertionError: if scaling would produce an image with no pixels.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newH > 0 and newW > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img_ndarray = np.asarray(pil_img)

        if is_mask:
            return img_ndarray

        if img_ndarray.ndim == 2:
            img_ndarray = img_ndarray[np.newaxis, ...]  # (H, W) -> (1, H, W)
        else:
            img_ndarray = img_ndarray.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)
        # Maps values to [0, 1] assuming 8-bit pixels — TODO confirm for other bit depths.
        return img_ndarray / 255

    @classmethod
    def load(cls, filename):
        """Load *filename* as a PIL image, dispatching on its (case-insensitive) extension.

        ``.npy`` / ``.npz`` files are read with NumPy, ``.pt`` / ``.pth`` files with
        ``torch.load`` (the loaded object must provide ``.numpy()``), and anything
        else is opened with ``Image.open``.

        Raises:
            ValueError: if a ``.npz`` archive does not contain exactly one array.
        """
        ext = splitext(filename)[1].lower()
        if ext == '.npy':
            return Image.fromarray(np.load(filename))
        if ext == '.npz':
            # np.load returns an NpzFile archive (not an array) for .npz files,
            # so the stored array has to be pulled out explicitly.
            with np.load(filename) as archive:
                if len(archive.files) != 1:
                    raise ValueError(f'Expected exactly one array in {filename}, found {archive.files}')
                return Image.fromarray(archive[archive.files[0]])
        if ext in ('.pt', '.pth'):
            # NOTE(review): torch.load is pickle-based; only use on trusted files.
            return Image.fromarray(torch.load(filename).numpy())
        return Image.open(filename)

    def __getitem__(self, idx):
        """Return ``{'image': float tensor, 'mask': long tensor}`` for sample *idx*.

        Raises:
            AssertionError: if the ID does not resolve to exactly one image and one
                mask, or if the image and mask sizes differ.
        """
        name = self.ids[idx]
        # Masks are only prefix-matched (they may use a different, even multi-part,
        # extension); images must match the ID exactly since IDs come from splitext.
        mask_file = self._find_files(self.masks_dir, name + self.mask_suffix)
        img_file = self._find_files(self.images_dir, name, exact_stem=True)
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'

        mask = self.load(mask_file[0])
        img = self.load(img_file[0])
        assert img.size == mask.size, f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(img, self.scale, is_mask=False)
        mask = self.preprocess(mask, self.scale, is_mask=True)

        # copy() gives torch its own writable buffer (np.asarray of a PIL image may be read-only).
        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }
class CarvanaDataset(BasicDataset):
    """``BasicDataset`` preset whose masks use the image ID with an empty suffix."""

    def __init__(self, images_dir, masks_dir, scale=1):
        super(CarvanaDataset, self).__init__(
            images_dir,
            masks_dir,
            scale=scale,
            mask_suffix='',
        )
unet read_data.py 解析
于 2022-07-23 10:29:34 首次发布