自定义pytorch训练框架1(图像处理):简单的自定义dataset

文章介绍了一个Python类Dataset,用于处理包含多种图像格式(如jpg、jpeg、png和tif)的目录。它提供方法加载图像,对.tif文件特别处理,如果遇到问题会抛出异常。类还实现了索引访问和获取图像文件路径列表的功能。
摘要由CSDN通过智能技术生成
import os
from PIL import Image
import numpy as np
import tifffile  # 导入tifffile库

class Dataset:
    def __init__(self, root_dir, image_extensions=('jpg', 'jpeg', 'png', 'tif')):
        self.root_dir = root_dir
        self.image_extensions = image_extensions
        self.image_paths = self._find_images()

    def _find_images(self):
        image_paths = []
        for ext in self.image_extensions:
            for root, _, files in os.walk(self.root_dir):
                for file in files:
                    if file.lower().endswith('.' + ext):
                        image_paths.append(os.path.join(root, file))
        return image_paths

    def _load_image(self, image_path):
        ext = image_path.lower().split('.')[-1]
        if ext in ('jpg', 'jpeg', 'png'):
            image = Image.open(image_path).convert('RGB')
        elif ext == 'tif':
            try:
                image = tifffile.imread(image_path)
            except Exception as e:
                print(f"Error loading TIFF file {image_path}: {e}")
                raise
        else:
            raise ValueError(f"Unsupported image format for {image_path}")
        return image

    def __getitem__(self, index):
        if index < len(self.image_paths):
            image_path = self.image_paths[index]
            image_data = self._load_image(image_path)
            return image_data
        else:
            raise IndexError("Index out of range.")

    def __len__(self):
        return len(self.image_paths)

# 使用示例
if __name__ == "__main__":
    dataset = Dataset(root_dir='path/to/your/dataset', image_extensions=('jpg', 'png', 'tif'))
    image_data, image_path = dataset[0]
    print(f"Image loaded from {image_path}")

    _load_image 方法现在会检查文件扩展名。如果文件是 .tif 格式,它会尝试使用 tifffile.imread 函数来读取图像。如果 tifffile 无法读取文件(例如,由于缺少依赖项或其他原因),它将抛出一个异常。如果文件是其他支持的格式(.jpg.jpeg.png),它将像之前一样使用 Pillow 来读取图像。

           请注意,tifffilepytiff 库的功能和接口可能有所不同。在这个例子中,我们使用了 tifffile,因为它提供了简单直接的 TIFF 文件读取功能。如果你选择使用 pytiff,则可能需要根据其API调整代码。

           _find_images 的方法,用于在指定的目录中查找所有图片文件,并将这些文件的路径添加到 image_paths 列表中。下面是对代码的逐行解释:

  1. def _find_images(self):

    • 定义一个名为 _find_images 的方法,该方法是一个实例方法(因为它有 self 参数)。
  2. image_paths = []

    • 初始化一个空列表 image_paths,用于存储找到的图片文件的路径。
  3. for ext in self.image_extensions:

    • 遍历 self.image_extensions 列表或元组中的每一个元素。self.image_extensions 应该是一个包含图片文件扩展名的列表或元组,例如 ['jpg', 'png', 'jpeg']
  4. for root, _, files in os.walk(self.root_dir):

    • 使用 os.walk 函数遍历 self.root_dir 目录及其所有子目录。os.walk 会为每个目录返回一个三元组:目录路径(root)、目录名列表(这里用 _ 忽略)和文件名列表(files)。
  5. for file in files:

    • 遍历当前目录下的所有文件名。
  6. if file.lower().endswith('.' + ext):

    • 检查当前文件名(转换为小写后)是否以指定的扩展名(前面加上 .)结尾。这样做是为了确保文件名匹配且不受大小写影响。
  7. image_paths.append(os.path.join(root, file))

    • 如果文件名匹配,使用 os.path.join 函数将目录路径和文件名合并为一个完整的文件路径,然后将这个路径添加到 image_paths 列表中。
  8. return image_paths

    • 返回 image_paths 列表,其中包含了所有找到的图片文件的路径。

这个 _find_images 方法的作用是在 self.root_dir 指定的目录及其所有子目录中查找所有具有 self.image_extensions 中指定扩展名的图片文件,并返回这些文件的路径列表。

  • 9
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值