unet read_data.py 解析

import logging
from os import listdir
from os.path import splitext
from pathlib import Path

import numpy as np
from PIL import Image

from torch.utils.data import Dataset
import torch
import os
import cv2 as cv

class BasicDataset(Dataset):
    """Segmentation dataset pairing each file in ``images_dir`` with a mask in ``masks_dir``.

    Every non-hidden regular file ``<id><ext>`` in ``images_dir`` defines one
    sample id. Its mask is the file in ``masks_dir`` whose name, minus its last
    extension, equals ``<id> + mask_suffix``. Samples are returned as
    ``{'image': FloatTensor, 'mask': LongTensor}``.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        scale: Resize factor in ``(0, 1]`` applied to both image and mask.
        mask_suffix: Text appended to an image id to form its mask's name
            (e.g. ``'_mask'`` pairs ``abc.jpg`` with ``abc_mask.gif``).

    Raises:
        RuntimeError: If ``images_dir`` contains no usable files.
    """

    def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        self.mask_suffix = mask_suffix

        # One id (file name without its last extension) per non-hidden regular
        # file; sub-directories are skipped so they don't become ids that can
        # never be loaded.
        self.ids = [
            splitext(file)[0]
            for file in listdir(images_dir)
            if not file.startswith('.') and (self.images_dir / file).is_file()
        ]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale, is_mask):
        """Resize ``pil_img`` by ``scale`` and convert it to a numpy array.

        Masks are resized with nearest-neighbour sampling, so label values are
        never blended, and are returned without further changes. Images are
        resized bicubically and put in channel-first layout; a leading channel
        axis is added for 2-D (single-channel) input. They are then divided by
        255, which maps 8-bit pixel values into ``[0, 1]``.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newH > 0 and newW > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img_ndarray = np.asarray(pil_img)

        if is_mask:
            return img_ndarray

        if img_ndarray.ndim == 2:
            img_ndarray = img_ndarray[np.newaxis, ...]  # (H, W) -> (1, H, W)
        else:
            img_ndarray = img_ndarray.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)

        return img_ndarray / 255

    @classmethod
    def load(cls, filename):
        """Load ``filename`` as a PIL image.

        Files are read by extension (case-insensitive):

        - ``.npy``: read with ``np.load``.
        - ``.npz``: read with ``np.load``, using the first array in the archive.
        - ``.pt`` / ``.pth``: read with ``torch.load``.
        - Anything else: opened with ``Image.open``.

        Raises:
            ValueError: If a ``.npz`` archive contains no arrays.
        """
        ext = splitext(filename)[1].lower()
        if ext == '.npy':
            return Image.fromarray(np.load(filename))
        if ext == '.npz':
            # For .npz, np.load returns an NpzFile archive rather than an
            # array: take its first member and close the archive afterwards.
            with np.load(filename) as archive:
                if not archive.files:
                    raise ValueError(f'No arrays found in {filename}')
                return Image.fromarray(archive[archive.files[0]])
        if ext in ('.pt', '.pth'):
            # NOTE: torch.load is pickle-based; only use it on trusted files.
            return Image.fromarray(torch.load(filename).numpy())
        return Image.open(filename)

    @staticmethod
    def _files_with_stem(directory: Path, stem: str) -> list:
        """Return the sorted regular files in ``directory`` whose name minus its last extension is ``stem``.

        This is the same rule used to build ``ids``. Unlike
        ``directory.glob(stem + '.*')``, it does not treat ``[``, ``]``, ``*``
        or ``?`` in ``stem`` as wildcards. It also does not match longer names
        such as ``stem.extra.png``.
        """
        return sorted(
            path for path in directory.iterdir()
            if path.is_file() and splitext(path.name)[0] == stem
        )

    def __getitem__(self, idx):
        """Return ``{'image': FloatTensor, 'mask': LongTensor}`` for sample ``idx``."""
        name = self.ids[idx]
        mask_file = self._files_with_stem(self.masks_dir, name + self.mask_suffix)
        img_file = self._files_with_stem(self.images_dir, name)

        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'

        mask = self.load(mask_file[0])
        img = self.load(img_file[0])

        assert img.size == mask.size, f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(img, self.scale, is_mask=False)
        mask = self.preprocess(mask, self.scale, is_mask=True)

        # .copy() hands torch a writable array it owns: np.asarray on a PIL
        # image may return a read-only view.
        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

class CarvanaDataset(BasicDataset):
    """``BasicDataset`` with an empty ``mask_suffix``.

    Each mask is looked up under exactly the same id as its image.
    NOTE(review): the mask-lookup comment elsewhere in this file mentions a
    ``_mask`` infix for this dataset; confirm the empty suffix matches your data.
    """

    def __init__(self, images_dir, masks_dir, scale=1):
        # Forward by keyword so the mapping onto the base constructor's
        # parameters is explicit.
        super().__init__(
            images_dir=images_dir,
            masks_dir=masks_dir,
            scale=scale,
            mask_suffix='',
        )



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值