自定DatasetLoad(数据加载器)以及一些图像增强方法--笔记LOG

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
    @Project :pt_tf_lea 
    @Author  :Anjou
    @Date    :2023/5/15 13:38 
"""
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import random
import PIL
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    # 如果图像最小边长小于给定size,则用数值fill进行padding
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose(object):
    # 构建处理图像的transform的处理pipeline
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomResize(object):
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)
        # 这里size传入的是int类型,所以是将图像的最小边长缩放到size大小
        image = F.resize(image, size)
        # 这里的interpolation注意下,在torchvision(0.9.0)以后才有InterpolationMode.NEAREST
        # 如果是之前的版本需要使用PIL.Image.NEAREST
        target = F.resize(target, size, interpolation=PIL.Image.NEAREST)
        return image, target


class RandomHorizontalFlip(object):
    # 随机翻转图像
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop(object):
    # 随机裁剪图像
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        # 首先要确定所裁剪图像不要小于标准图像
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        # 得到随机裁剪的参数,返回坐标x,y 和 裁剪框的h, w
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop(object):
    # 中心裁剪
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


class Normalize(object):
    # 图像标准化,设定均值和方差,减均值除方差,将数据标准化为正态分布
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target


class loadDataset(Dataset):
    "自定义数据集加载器"

    def __init__(self, ROOT_IMAGE: str, ROOT_TARGET: str, TRANSFORM=None):
        """
        :param ROOT_IMAGE: 图像目录
        :param ROOT_TARGET: GT目录
        :param TRANSFORM: 增广方法
        """
        self.imagePaths = [os.path.join(ROOT_IMAGE, i) for i in os.listdir(ROOT_IMAGE)]
        self.targetPaths = [os.path.join(ROOT_TARGET, i) for i in os.listdir(ROOT_TARGET)]
        self.imagePaths.sort()  # 对两者排序确认数据对应
        self.targetPaths.sort()
        self.transform = TRANSFORM

    def __getitem__(self, item):
        image = Image.open(self.imagePaths[item])
        if image.mode is not 'RGB':
            raise ValueError(f'{self.imagePaths[item]} is not RGB mode')
        target = self.targetPaths[item]
        if self.transform:
            image = self.transform(image)
        return image, target

    def __len__(self):
        return len(self.imagePaths)

    @staticmethod
    def collect_fn(batch):
        "兼容不同大小图像"
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fillValue=0)
        # 如果为mask,则填充,否则不做处理
        # if mask:
        #     batched_targets = cat_list(targets, fillValue=255)
        batched_targets = targets
        return batched_imgs, batched_targets


def cat_list(images, fillValue=0):
    maxSize = tuple(max(s) for s in zip(*[img.shape for img in images]))  # 获取batch图像中最大尺寸的c,h和w
    batch_shape = (len(images),) + maxSize  # 变为批次维度 NCHW
    batched_imgs = images[0].new(*batch_shape).fill_(fillValue)  # 创建batch_shape同纬度的mask蒙版
    for image, pad_image in zip(images, batched_imgs):  # 为蒙版填充原图像,将批次内不同大小的图像统一为最大底图像(相当于填充满边框)
        pad_image[..., :image.shape[-2], :image.shape[-1]].copy_(image)
    return batched_imgs


class TransformTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        max_size = int(1.5 * base_size)
        min_size = int(0.5 * base_size)
        trans_list = [RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans_list.append(RandomHorizontalFlip(hflip_prob))
        trans_list.extend(
            [RandomCrop(crop_size),
             ToTensor(),
             Normalize(mean, std)]
        )
        self.transforms = T.Compose(trans_list)

    def __call__(self, image, target):
        return self.transforms(image, target)


class TransformVal:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = Compose([
            RandomResize(base_size, base_size),
            ToTensor(),
            Normalize(mean, std)
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)


def get_transform(train):
    base_size = 520
    crop_size = 480
    return TransformTrain(base_size, crop_size) if train else TransformVal(base_size)

以上内容作为备忘,需要的小伙伴自取咯~~~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值