import os
from PIL import Image
import numpy as np
import tifffile # 导入tifffile库
class Dataset:
def __init__(self, root_dir, image_extensions=('jpg', 'jpeg', 'png', 'tif')):
self.root_dir = root_dir
self.image_extensions = image_extensions
self.image_paths = self._find_images()
def _find_images(self):
image_paths = []
for ext in self.image_extensions:
for root, _, files in os.walk(self.root_dir):
for file in files:
if file.lower().endswith('.' + ext):
image_paths.append(os.path.join(root, file))
return image_paths
def _load_image(self, image_path):
ext = image_path.lower().split('.')[-1]
if ext in ('jpg', 'jpeg', 'png'):
image = Image.open(image_path).convert('RGB')
elif ext == 'tif':
try:
image = tifffile.imread(image_path)
except Exception as e:
print(f"Error loading TIFF file {image_path}: {e}")
raise
else:
raise ValueError(f"Unsupported image format for {image_path}")
return image
def __getitem__(self, index):
if index < len(self.image_paths):
image_path = self.image_paths[index]
image_data = self._load_image(image_path)
return image_data
else:
raise IndexError("Index out of range.")
def __len__(self):
return len(self.image_paths)
# 使用示例
if __name__ == "__main__":
dataset = Dataset(root_dir='path/to/your/dataset', image_extensions=('jpg', 'png', 'tif'))
image_data, image_path = dataset[0]
print(f"Image loaded from {image_path}")
_load_image
方法现在会检查文件扩展名。如果文件是 .tif
格式,它会尝试使用 tifffile.imread
函数来读取图像。如果 tifffile
无法读取文件(例如,由于缺少依赖项或其他原因),它将抛出一个异常。如果文件是其他支持的格式(.jpg
、.jpeg
、.png
),它将像之前一样使用 Pillow 来读取图像。
请注意,tifffile
和 pytiff
库的功能和接口可能有所不同。在这个例子中,我们使用了 tifffile
,因为它提供了简单直接的 TIFF 文件读取功能。如果你选择使用 pytiff
,则可能需要根据其API调整代码。
_find_images
的方法,用于在指定的目录中查找所有图片文件,并将这些文件的路径添加到 image_paths
列表中。下面是对代码的逐行解释:
-
def _find_images(self):
- 定义一个名为
_find_images
的方法,该方法是一个实例方法(因为它有self
参数)。
- 定义一个名为
-
image_paths = []
- 初始化一个空列表
image_paths
,用于存储找到的图片文件的路径。
- 初始化一个空列表
-
for ext in self.image_extensions:
- 遍历
self.image_extensions
列表或元组中的每一个元素。self.image_extensions
应该是一个包含图片文件扩展名的列表或元组,例如['jpg', 'png', 'jpeg']
。
- 遍历
-
for root, _, files in os.walk(self.root_dir):
- 使用
os.walk
函数遍历self.root_dir
目录及其所有子目录。os.walk
会为每个目录返回一个三元组:目录路径(root
)、目录名列表(这里用_
忽略)和文件名列表(files
)。
- 使用
-
for file in files:
- 遍历当前目录下的所有文件名。
-
if file.lower().endswith('.' + ext):
- 检查当前文件名(转换为小写后)是否以指定的扩展名(前面加上
.
)结尾。这样做是为了确保文件名匹配且不受大小写影响。
- 检查当前文件名(转换为小写后)是否以指定的扩展名(前面加上
-
image_paths.append(os.path.join(root, file))
- 如果文件名匹配,使用
os.path.join
函数将目录路径和文件名合并为一个完整的文件路径,然后将这个路径添加到image_paths
列表中。
- 如果文件名匹配,使用
-
return image_paths
- 返回
image_paths
列表,其中包含了所有找到的图片文件的路径。
- 返回
这个 _find_images
方法的作用是在 self.root_dir
指定的目录及其所有子目录中查找所有具有 self.image_extensions
中指定扩展名的图片文件,并返回这些文件的路径列表。