玩转CV数据集:从获取、分析到PyTorch实战
第一部分:数据集的获取
在计算机视觉的征途上,找到合适的数据集是项目成功的第一步。幸运的是,我们拥有众多优秀的平台和资源,它们不仅提供了丰富的数据,还往往附带了便捷的下载工具和方法。
1.1. 主流CV数据集平台概览
计算机视觉领域的数据集资源非常丰富,分布在各种类型的平台上。了解这些平台的特点和获取方式,能帮助我们更快地找到研究所需的数据。
1.1.1. Kaggle:数据科学竞赛与数据集中心
Kaggle (www.kaggle.com) 不仅是全球知名的数据科学竞赛平台,也是一个庞大的公共数据集存储库。这里汇集了大量用户上传和官方维护的数据集,覆盖了从图像分类、目标检测到图像分割等多种CV任务。
获取方式:
- 网页直接下载:用户可以直接在Kaggle网站上浏览并下载数据集。
- Kaggle API/CLI:对于大型数据集或需要自动化下载的场景,Kaggle提供了官方API和命令行工具(CLI)。首先,你需要在Kaggle账户设置中创建API Token,并将下载的
kaggle.json
文件放置在指定目录(Linux/OSX通常是~/.kaggle/kaggle.json
,Windows是C:\Users\<Windows-username>\.kaggle\kaggle.json
)。之后便可以使用命令行进行操作。- 列出数据集:
kaggle datasets list -s [关键词]
- 下载数据集:
kaggle datasets download -d [owner_username]/[dataset_slug]
。例如,下载著名的“Dogs vs. Cats”数据集:
或者下载一个特定的社区数据集:kaggle competitions download -c dogs-vs-cats
下载时可以使用kaggle datasets download -d anasmohammedtahir/celeba-dataset
-p <路径>
指定下载目录,--unzip
参数可以在下载后自动解压。
- 列出数据集:
Kaggle平台上的数据集通常由社区成员或竞赛组织者提供,因此数据质量和标注规范性可能有所不同。在使用前,仔细阅读数据集描述和讨论区是非常重要的。
1.1.2. Hugging Face Hub:模型与数据集的开源社区
Hugging Face Hub (huggingface.co/datasets) 已经成为机器学习领域,特别是自然语言处理和计算机视觉模型与数据集的重要集散地。它不仅托管了大量的预训练模型,还提供了对众多CV数据集的便捷访问。
获取方式:
datasets
库:这是Hugging Face官方推荐的与数据集交互的方式。通过load_dataset
函数,可以非常方便地加载Hub上的数据集,甚至是一些本地数据集。该库底层使用Apache Arrow格式进行高效的数据存储和读取,支持内存映射,对处理大规模数据集非常友好。from datasets import load_dataset # 例如,加载一个Hugging Face Hub上的图像分类数据集 # food_dataset = load_dataset("food101", split="train[:5000]") # print(food_dataset) # 加载LaSOT数据集 (如果其在Hub上有直接的加载脚本) # lasot_dataset = load_dataset("l-lt/LaSOT", trust_remote_code=True) # trust_remote_code可能需要,如果它使用了自定义脚本 # print(lasot_dataset)
load_dataset
函数支持多种参数,如name
(选择数据集的特定配置或子集)、split
(选择训练/验证/测试集,甚至切片)和data_dir
(指定数据文件子目录)。它会自动处理下载、缓存和解析过程。snapshot_download()
:对于那些不仅仅是数据文件,还包含其他辅助脚本或特定目录结构的完整数据集仓库,可以使用huggingface_hub
库中的snapshot_download
函数。它可以下载整个仓库到本地。
参数包括:repo_id
(仓库ID,如"l-lt/LaSOT"
),repo_type="dataset"
,local_dir
(本地保存路径),以及allow_patterns
和ignore_patterns
用于选择性下载文件(例如,只下载.safetensors
权重文件或特定类型的标注文件)。
示例 (LaSOT):
这种方式对于获取如LaSOT这类包含完整结构和评估工具的基准数据集非常有用。from huggingface_hub import snapshot_download # lasot_path = snapshot_download(repo_id="l-lt/LaSOT", repo_type="dataset", local_dir="./lasot_data") # print(f"LaSOT downloaded to: {lasot_path}")
- 特定CV数据集:Hugging Face Hub上有许多知名的CV数据集,例如LaSOT, Conceptual Captions,以及各种COCO数据集的变体或子集。
Hugging Face Hub的一大优势在于其生态系统的整合性,datasets
库不仅简化了数据获取,还与transformers
库紧密配合,方便后续的模型训练和评估。其内置的数据集查看器(Dataset Viewer)也使得在线浏览和理解数据集结构变得容易。
1.1.3. 学术机构与研究门户
许多开创性的CV数据集最初都源于学术界的研究项目。
- PapersWithCode (paperswithcode.com/datasets):这是一个将科研论文、代码实现和相关数据集紧密联系起来的优秀平台。它本身通常不直接托管数据集,而是提供指向原始数据源的链接。你可以根据任务(如目标检测、语义分割)、模态(图像、视频)或论文来发现和追踪最新的数据集。
- 大学/实验室网站:很多顶尖大学和研究实验室会维护自己的数据集发布页面。例如,LaSOT数据集的主页就设在石溪大学计算机视觉实验室(CVL)的网站上。这类网站通常提供直接的下载链接,可能指向OneDrive、Google Drive或百度网盘等云存储服务。
- VisualData.io (visualdata.io):这是一个专门为计算机视觉领域设计的数据集搜索引擎,用户可以通过主题和关键词来查找数据集。它更侧重于发现,然后引导用户到数据集的原始托管位置。
- Google Dataset Search (datasetsearch.research.google.com):这是一个通用的数据集搜索引擎,它依赖于数据集发布者提供的schema.org元数据来索引全球的开放数据集。它不直接托管数据,而是提供链接。
从学术渠道获取数据集,通常能保证数据的权威性和与最新研究的紧密关联性。但下载方式和数据组织形式可能不如商业平台那样标准化。
1.1.4. 专业化数据集集合
针对特定的CV研究方向,也涌现出了一些专业化的数据集平台和集合。
- MOTChallenge (motchallenge.net):这是多目标跟踪(Multiple Object Tracking)领域的权威基准测试平台。它提供了一系列数据集(如MOT17、MOT20),包含图像序列、真实的标注(ground truth)以及公开的检测结果,用于评估跟踪算法的性能。数据集通常以ZIP压缩包的形式直接从其官网下载。
- LAION (laion.ai):LAION组织发布了多个大规模的开放图文对数据集,如LAION-400M和LAION-5B。LAION-5B包含高达58.5亿个经过CLIP模型过滤的图文对。由于其巨大的规模,直接下载全量图像对大多数用户不现实。因此,LAION主要提供元数据(包含图片URL、文本描述、CLIP相似度得分等),这些元数据以高效的Parquet文件格式存储。官方推荐使用
img2dataset
工具根据这些URL下载实际的图像数据。 - Roboflow Universe (universe.roboflow.com):Roboflow提供了一个包含大量公开数据集的平台,许多数据集都经过了预处理或带有标注,用户可以直接搜索和使用。Roboflow还提供了数据增强、模型训练和部署等一系列工具。数据集可以通过其API或CLI进行下载。例如,使用CLI下载的命令格式为:
roboflow download -f <格式> -l <下载位置> <workspaceId>/<projectId>/<versionNumber>
。
这些专业化数据集的出现,极大地推动了特定领域(如大规模预训练、多目标跟踪)的研究进展。它们通常伴随着专门的工具链和获取方法,以应对其独特的规模和数据特性。例如,LAION数据集通过提供元数据和img2dataset
工具,使得研究者可以按需下载和处理数据,而不是一次性下载整个TB级别的数据集,这体现了对大规模数据处理方式的一种转变。
表1: 主要CV数据集平台与访问方法
平台 | 主要访问方法 | 常见CV数据集示例 |
---|---|---|
Kaggle | 网页UI, Kaggle API/CLI | COCO子集, ImageNet子集, 各种社区贡献数据集 |
Hugging Face Hub | datasets 库 (load_dataset , snapshot_download ) | LaSOT, Conceptual Captions, COCO (多种版本), 其他CV任务数据集 |
PapersWithCode | 链接到原始数据源 | 广泛的研究数据集索引 |
大学/实验室网站 (如Stony Brook CVL) | 直接下载链接 (HTTP/FTP/云存储如OneDrive, Google Drive) | LaSOT, Nucleus classification dataset |
MOTChallenge | 官网直接下载 (通常为ZIP文件) | MOT17, MOT20 |
LAION | 元数据 (Parquet) + img2dataset 工具下载图像 | LAION-5B, LAION-400M |
Roboflow Universe | Roboflow API/CLI, 网页UI | Roboflow 100, 众多社区标注和增强的数据集 |
Google Dataset Search | 搜索引擎, 链接到数据源 | 广泛的跨学科数据集索引 |
VisualData.io | 搜索引擎, 链接到数据源 | 计算机视觉特定数据集索引 |
1.2. 高效下载大规模数据集的工具与技巧
面对动辄GB甚至TB级别的数据集,高效的下载工具和策略至关重要。
1.2.1. 标准命令行工具
对于直接提供HTTP/FTP下载链接的数据集,以下命令行工具能显著提升下载效率和稳定性:
wget
:- 基础用法:
wget <URL>
- 从文件列表批量下载:
wget -i urls.txt
- 断点续传:
wget -c <URL>
(对于大文件下载中断后非常有用) - 指定下载目录:
wget -P /path/to/directory <URL>
- 处理需要认证的下载(例如使用cookies):
wget --load-cookies cookies.txt <URL>
- 基础用法:
curl
:- 基础用法(保存为原始文件名):
curl -O <URL>
或curl --remote-name <URL>
- 自定义保存文件名:
curl -o new_filename.zip <URL>
- 自动跟随重定向:
curl -L <URL>
- 批量下载多个文件:
curl -O <URL1> -O <URL2>
- 基础用法(保存为原始文件名):
aria2c
:aria2c
是一款强大的多协议(HTTP/HTTPS, FTP, SFTP, BitTorrent, Metalink)多源下载工具,通过并发连接显著提高下载速度。- 基础用法:
aria2c <URL>
- 指定并发连接数:
aria2c -s16 <URL>
(每个服务器最多16个连接) 或-x16 <URL>
(单个文件总共最多16个连接) - 断点续传:
aria2c -c <URL>
- 从文件列表批量下载:
aria2c -i urls.txt
对于大型数据集,特别是那些以单个大文件形式提供下载的,aria2c
的多线程能力能够充分利用网络带宽,大幅缩短下载时间。
1.2.2. 专用Python库/工具
- Hugging Face
huggingface_hub
库:hf_hub_download()
:用于下载Hub上的单个文件,并进行缓存管理。主要参数包括repo_id
,filename
,repo_type
,revision
,cache_dir
。snapshot_download()
:用于下载整个模型或数据集仓库,支持并发下载和缓存。关键参数有repo_id
,repo_type
,revision
,cache_dir
,local_dir
,以及强大的allow_patterns
和ignore_patterns
用于选择性下载文件。例如,可以只下载仓库中的.safetensors
权重文件或特定子目录的内容。
这些工具专为Hugging Face Hub设计,能很好地处理版本控制和缓存,避免重复下载,并且通过模式匹配实现按需下载,这对于大型仓库尤其重要。
img2dataset
:- 这是LAION等大规模图文对数据集(它们通常只提供元数据,包含图片URL和文本描述)的核心下载工具。它能从URL列表高效下载图片,并进行预处理(如调整大小),最终以webdataset等格式存储,便于后续训练。
- 关键命令行参数:
--url_list
: 元数据文件路径(如Parquet文件目录或TXT文件)。--input_format
: 元数据文件格式(如"parquet"
,"tsv"
,"csv"
)。--url_col
: 元数据中图片URL所在的列名 (如"URL"
)。--caption_col
: 元数据中图片描述所在的列名 (如"TEXT"
)。--output_format
: 输出格式,通常为"webdataset"
。--output_folder
: 下载数据保存目录。--processes_count
: 使用的CPU进程数。--thread_count
: 每个进程的线程数。--image_size
: 调整图像大小。--save_additional_columns
: 需要一同保存的其他元数据列名列表。
- 示例:
img2dataset --url_list laion400m-meta --input_format "parquet" \ --url_col "URL" --caption_col "TEXT" --output_format webdataset \ --output_folder laion400m-data --processes_count 16 --thread_count 128 \ --image_size 256 --save_additional_columns '["similarity", "WIDTH", "HEIGHT"]'
img2dataset
的设计理念是应对海量URL列表的图像下载需求。它不仅下载,还能在下载过程中进行图像缩放、格式转换等预处理,并打包成适合大规模分布式训练的webdataset格式,极大地简化了从原始URL到可用训练数据的流程。
第二部分:经典CV数据集深度解析
了解了如何获取数据集后,我们来看看几个在CV领域具有里程碑意义的数据集,分析它们的特点、规模、主要任务以及数据是如何存储的。
2.1. 核心数据集剖析(CV)
2.1.1. ImageNet
ImageNet 是一个根据WordNet层级结构组织的大规模图像数据库。最初包含超过1400万张图像,覆盖约22000个synset(同义词集)。其中,最为人熟知的是ImageNet大规模视觉识别挑战赛(ILSVRC)所使用的数据集子集,该子集包含约120万张训练图像、5万张验证图像和10万张测试图像,共1000个对象类别。
- 主要任务:图像分类是其最核心的任务,同时也支持目标检测/定位任务。
- 图像格式:图像主要以JPEG格式存储。
- 标注格式:
- 图像分类:通常采用图像级标签,即每个图像对应一个类别(synset ID)。在实际存储中,训练图像常被组织在以类别ID命名的子文件夹下。
- 目标检测/定位 (ILSVRC):标注信息以PASCAL VOC类似的XML文件格式存储,每个图像对应一个XML文件。
- XML关键标签包括:
<folder>
,<filename>
,<size>
(包含<width>
,<height>
,<depth>
), 以及每个对象的<object>
标签。 <object>
标签内部包含:<name>
(通常是WordNet ID, WNID),<pose>
,<truncated>
,<difficult>
, 以及定义边界框的<bndbox>
(包含<xmin>
,<ymin>
,<xmax>
,<ymax>
)。这些坐标是绝对像素值。- 需要注意的是,XML中标注的图像尺寸可能与实际图像文件尺寸略有出入。
- XML关键标签包括:
- 获取途径:完整数据集的下载需要在ImageNet官网注册并同意使用条款。ILSVRC子集也可以在Kaggle上找到。下载的文件通常是
.tar
压缩包,如ILSVRC2012_img_train.tar
。
2.1.2. COCO (Common Objects in Context)
COCO (Common Objects in Context) 是一个为目标检测、实例分割、关键点检测、场景理解(stuff segmentation)、全景分割和图像描述生成等多种复杂视觉任务设计的大规模数据集。它强调“上下文中的常见物体”,包含约33万张图像(其中超过20万张有标注),标注了超过150万个对象实例。
- 主要任务:目标检测、实例分割、关键点检测、图像描述等。
- 图像格式:主要是JPEG格式。数据集通常划分为train2017、val2017、test2017等子集,图像分别存放在对应年份的文件夹中。
- 标注格式:JSON。COCO采用统一的JSON文件来存储一个数据分割(如训练集或验证集)的所有标注信息,例如
instances_train2017.json
用于目标检测和分割,captions_train2017.json
用于图像描述。其核心结构如下:"info"
: 包含数据集的描述、版本、年份、贡献者、创建日期等元信息。"licenses"
: 包含图像使用的许可证列表,每个许可证有唯一的id、name和url。"images"
: 图像信息列表,每个图像对象包含id
(图像唯一标识符)、width
(图像宽度)、height
(图像高度)、file_name
(图像文件名)、license
(许可证ID)、coco_url
、flickr_url
、date_captured
(拍摄日期) 等字段。"categories"
: 类别信息列表,每个类别对象包含id
(类别唯一标识符)、name
(类别名称,如"person"
,"car"
) 和supercategory
(父类别名称,如"vehicle"
)。COCO定义了80个“物体”(thing)类别和91个“背景材质”(stuff)类别。"annotations"
: 标注信息列表,是核心部分,其结构因任务类型而异:- 通用字段: 每个标注对象都有
id
(标注的唯一ID)、image_id
(关联到images列表中的图像ID)、category_id
(关联到categories列表中的类别ID,图像描述任务除外)。 - 目标检测 (Bounding Boxes):
"bbox"
: 一个包含四个数字的列表[x, y, width, height]
,表示边界框的左上角x坐标、左上角y坐标、宽度和高度(单位:像素)。"area"
: 边界框的面积(像素单位)。"iscrowd"
: 标记是否为对象群组。0表示单个对象,1表示对象群组(例如一群人)。"segmentation"
: 即使是目标检测任务,也可能包含分割信息。
- 实例分割 (Instance Segmentation):
"segmentation"
:- 当
iscrowd=0
(单个对象)时,通常是一个或多个多边形顶点列表,格式为[[x1,y1,x2,y2,...]]
。如果一个对象有多个不连续部分,则为[[polygon1_vertices], [polygon2_vertices]]
。 - 当
iscrowd=1
(对象群组)时,通常使用行程长度编码(Run-Length Encoding, RLE)对象来表示分割掩码。RLE对象是一个字典,包含"size": [height, width]
(图像的原始高和宽) 和"counts": [rle_counts_array]
(一个整数数组,表示交替的0和1的游程长度)。
- 当
"area"
: 分割掩码的面积(像素单位)。"bbox"
: 包围分割掩码的边界框[x,y,width,height]
。
- 图像描述 (Image Captioning):
"image_id"
: 描述所针对的图像ID。"id"
: 该描述标注的唯一ID。"caption"
: 包含描述文本的字符串。例如:{"id": 1796, "image_id": 9, "caption": "A black and white photo of a large crowd of people."}
。通常每张图片有5条描述。
- 通用字段: 每个标注对象都有
表3: COCO Annotation JSON - 关键部分与字段
主要部分 | 字段名 | 数据类型 | 描述 | 示例 (来自文献) |
---|---|---|---|---|
info | description | string | 数据集描述 | "Example COCO Dataset" |
url | string | 数据集URL | "https://www.example.com/dataset" | |
version | string | 数据集版本 | "1.0" | |
year | integer | 数据集年份 | 2023 | |
contributor | string | 贡献者 | "Sarah Connor" | |
date_created | string | 创建日期 (YYYY-MM-DD) | "1964-09-14" | |
licenses | id | integer | 许可证ID | 1 |
name | string | 许可证名称 | "Attribution-NonCommercial-ShareAlike License" | |
url | string | 许可证URL | "https://creativecommons.org/licenses/by-nc-sa/4.0/" | |
images | id | integer | 图像唯一ID | 397133 |
width | integer | 图像宽度 (像素) | 640 | |
height | integer | 图像高度 (像素) | 427 | |
file_name | string | 图像文件名 | "000000397133.jpg" | |
license | integer | 许可证ID (关联到licenses部分) | 4 | |
coco_url | string | COCO图像URL | "http://images.cocodataset.org/val2017/000000397133.jpg" | |
flickr_url | string | Flickr图像URL | "http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg" | |
date_captured | string | 图像拍摄日期 | "2013-11-14 17:02:52" | |
categories | id | integer | 类别唯一ID | 1 |
name | string | 类别名称 | "person" | |
supercategory | string | 父类别名称 | "animal" (for "giraffe" ) | |
annotations | (通用) id | integer | 标注唯一ID | 1768 |
image_id | integer | 关联的图像ID | 289343 | |
category_id | integer | 关联的类别ID (检测/分割/关键点) | 18 (e.g., "dog" ) | |
(检测) bbox | list (float) | [x, y, width, height] 边界框坐标 | [473.07, 395.93, 38.65, 28.67] | |
area | float | 边界框/分割区域面积 | 702.1057 | |
iscrowd | 0 or 1 | 是否为对象群组 (0:否, 1:是) | 0 | |
(分割) segmentation | list/dict | 多边形 [[x1,y1,...]] 或 RLE {"size": [h,w], "counts":} | [[510.66, 423.01,...]] or {"counts": [179,27,...], "size": ...} | |
(描述) caption | string | 图像描述文本 | "A black and white photo of a large crowd of people." |
2.1.3. LaSOT (Large-scale Single Object Tracking)
LaSOT 是一个专为大规模单目标长时程跟踪任务设计的高质量基准数据集。会议版本包含1400个视频序列,总计超过350万帧图像;完整版则有1550个序列,超过387万帧。它涵盖了85个不同的物体类别,每个类别包含一定数量的序列(例如,70个类别各有20个序列,另外15个类别各有10个序列)。
- 主要任务:长时程单目标跟踪。LaSOT的视频平均长度约为2500帧(约83秒),这远超许多早期跟踪基准,旨在评估跟踪器在目标可能长时间消失后重现等复杂情况下的鲁棒性。
- 图像格式:视频序列由连续的图像帧组成,通常是JPEG或PNG格式。
- 标注格式:
- 每个视频序列都配有一个名为
groundtruth.txt
的文本文件,用于存储逐帧的标注信息。 groundtruth.txt
文件中的每一行对应视频中的一帧。- 每行数据由逗号分隔,包含以下信息:
x,y,width,height,full_occlusion,out_of_view
x,y,width,height
: 目标边界框的坐标,其中(x,y)
是边界框左上角的坐标,width
和height
分别是边界框的宽度和高度。full_occlusion
: 一个标志位,指示目标是否被完全遮挡(例如,0表示未遮挡或部分遮挡,1表示完全遮挡)。out_of_view
: 一个标志位,指示目标是否移出视野(例如,0表示在视野内,1表示已出画)。
- 除了边界框和状态标志,LaSOT还为每个视频提供了自然语言描述,这为结合语言特征进行跟踪研究提供了可能性。
- 每个视频序列都配有一个名为
- 获取途径:可从官方网站(vision.cs.stonybrook.edu/~lasot/)通过OneDrive、Google Drive、百度网盘等链接下载。同时,LaSOT数据集(包括会议版
l-lt/LaSOT
和扩展版l-lt/LaSOT-ext
)也可以在Hugging Face Hub上找到并使用snapshot_download
下载。
2.1.4. Conceptual Captions & LAION
- 规模与范围:
- Conceptual Captions (CC):主要版本包括CC3M(约330万图文对)和CC12M(约1200万图文对)。其文本描述主要来源于网页图片的alt-text属性,经过一系列自动化流程进行提取、过滤和转换,以平衡描述的简洁性、信息量、流畅性和可学习性。
- LAION:规模更为庞大,例如LAION-400M包含4亿图文对,而LAION-5B则达到了惊人的58.5亿图文对。这些图文对同样从网络抓取(主要来自Common Crawl),并使用CLIP模型的图文相似度评分进行过滤,以保证图像和文本在语义上的相关性。
- 主要任务:这类数据集主要用于训练视觉-语言预训练模型(Vision-Language Models),支持的任务包括图像描述生成、文本到图像生成、图文检索等。
- 数据表示:
- Conceptual Captions:通常以TSV(Tab-Separated Values)文件形式提供,主要包含两列:
caption
(文本描述,已进行分词和转小写处理)和image_url
(图片的网络链接)。部分版本可能还包含由机器生成的图像标签。用户需要自行根据URL下载图片。 - LAION:元数据(包括图片URL、文本描述、图像宽度、高度、检测到的语言、CLIP图文相似度得分、NSFW(不适宜内容)概率、水印概率等)以Parquet文件格式存储。Parquet是一种高效的列式存储格式,非常适合处理如此大规模的元数据。图像同样需要用户根据URL下载,官方推荐使用
img2dataset
工具。
- Conceptual Captions:通常以TSV(Tab-Separated Values)文件形式提供,主要包含两列:
- 获取途径:
- Conceptual Captions:Google Research的GitHub页面提供了TSV文件的下载链接。此外,也可以在Hugging Face Hub 和Kaggle 上找到。
- LAION:元数据(Parquet文件)可在Hugging Face Hub(例如
laion/laion2B-en-joined
)和LAION官方项目页获取。img2dataset
是下载和处理图像的标准工具。
表2: 主流CV数据集概览
数据集名称 | 主要任务 | 大致规模 (图像/图文对 & 类别) | 典型图像格式 | 典型标注格式 | 核心特点/关注点 |
---|---|---|---|---|---|
ImageNet | 图像分类, 目标检测 | ILSVRC: ~120万图像 / 1000类 | JPEG | XML (检测), 文件夹名 (分类) | 层级化分类 (WordNet), 经典基准 |
COCO | 目标检测, 实例分割, 关键点检测, 图像描述等 | ~20万+标注图像 / 80+91类 | JPEG | JSON (包含bbox, polygon, RLE, caption) | 上下文中的物体, 多任务支持, 精细标注 |
LaSOT | 长时程单目标跟踪 | ~350万帧 / 85类 | 图像序列 | TXT (每行: x,y,w,h,遮挡标志,出视野标志) | 长时程, 密集标注, 遮挡/出视野处理 |
Conceptual Captions | 视觉-语言预训练, 图像描述, 图文检索 | CC3M: ~330万图文对; CC12M: 1200万图文对 | 图像URL | TSV (caption, image_url) | 网络规模的图文对, alt-text来源 |
LAION | 视觉-语言预训练, 文本到图像生成 | LAION-5B: 58.5亿图文对 | 图像URL | Parquet (元数据含URL, 文本, 相似度等), 使用img2dataset下载图像 | 超大规模图文对, CLIP过滤, 支持多语言 |
2.2. 常见数据存储约定
理解数据集中图像和标注的存储方式,对于后续的数据加载和处理至关重要。
2.2.1. 图像文件格式
- PNG (Portable Network Graphics):
- 用途:由于其无损压缩特性,PNG非常适合存储分割掩码(segmentation masks)。在分割任务中,每个像素的精确值代表了其类别或实例归属,任何损失都可能导致标注信息的破坏。
- 通道与位深:PNG支持多种颜色类型,包括调色板索引色、灰度图(可带alpha通道)以及真彩色RGB或RGBA图像。对于分割掩码,常见的是单通道灰度图(例如8位,像素值直接对应类别ID)或索引色图像(调色板将索引映射到特定颜色,间接表示类别)。其位深可以从每通道1位到16位不等。例如,一个RGBA图像可以有32位(每通道8位)或64位(每通道16位)的位深。
- 考量:无损压缩保证了像素值的精确性,这对于掩码至关重要。虽然对于自然图像,PNG文件通常比JPEG大,但在需要精确像素表达或透明度的场景(如叠加掩码进行可视化)中是首选。
- JPEG (Joint Photographic Experts Group):
- 用途:因其高压缩率,JPEG是数据集中存储自然场景照片最常用的格式。
- 通道与位深:通常是RGB三通道图像。标准的JPEG格式每通道使用8位,即24位真彩色。一些JPEG变种(如JPEG 2000)也支持更高的位深,例如12位。
- 压缩:JPEG采用的是有损压缩算法(基于离散余弦变换DCT)。压缩程度可以调整,以在文件大小和图像质量之间取得平衡。典型的JPEG压缩能在可接受的质量损失下达到10:1的压缩比。
- 考量:由于是有损压缩,JPEG不适合存储那些需要精确像素值的分割掩码,因为压缩过程会引入伪影并改变像素值。但对于大规模的自然图像训练数据,牺牲一定的图像保真度以换取更小的存储空间是完全可以接受的。
图像格式的选择直接影响存储效率和数据保真度。对于分割任务,掩码的精确性是第一位的,因此无损的PNG是标准选择。而对于图像本身,尤其是在大规模数据集中,JPEG的压缩效率使其成为主流。
2.2.2. 标注文件类型
标注信息,作为训练模型的“答案”,其组织和存储方式同样多样。
- JSON (JavaScript Object Notation):
- 结构:一种轻量级的数据交换格式,易于人阅读和编写,也易于机器解析和生成。采用键值对和嵌套结构(对象和数组)来表示数据。
- 用途:COCO数据集的标注就是以JSON格式存储的。它非常适合表示复杂的、具有层级关系的数据。
- CV场景:COCO的多方面标注(如目标检测框、多边形/RLE分割、关键点、图像描述、超类等)都得益于JSON的灵活性。
- XML (eXtensible Markup Language):
- 结构:基于标签的标记语言,具有层级结构,也易于人类阅读。
- 用途:PASCAL VOC数据集的标注格式,以及ImageNet目标检测任务的标注。
- CV场景:非常适合结构化的对象标注,特别是当对象有很多属性时(例如PASCAL VOC中的
<bndbox>
,<xmin>
,<name>
,<pose>
,<truncated>
,<difficult>
等标签)。
- TXT (Plain Text):
- 结构:简单文本文件,通常是逐行存储信息。每行可以代表一个完整的标注或标注的一部分,常使用逗号、空格等作为分隔符。
- 用途:YOLO系列目标检测算法的标注格式(每行通常是:
class_id x_center y_center width height
,坐标和尺寸经过归一化处理);LaSOT数据集的逐帧跟踪真值(ground truth);以及简单的文件路径列表或类别标签列表。 - CV场景:适用于标注结构相对简单,或优先考虑解析便捷性的场景。YOLO格式因其直接对应模型输入而非常流行。
- CSV/TSV (Comma/Tab-Separated Values):
- 结构:表格数据,由行和列组成。第一行通常是表头(列名)。CSV使用逗号分隔,TSV使用制表符分隔。
- 用途:存储元数据、简单的边界框标注(如:文件名,xmin,ymin,xmax,ymax,类别)、图文对(如Conceptual Captions中的
image_url, caption
)。 - CV场景:适合存储扁平化的标注列表或元数据,可以方便地使用Pandas等库加载为数据帧进行处理。
- Parquet:
- 结构:一种列式存储文件格式,为大规模数据分析和处理设计,具有高压缩率和查询效率。
- 用途:LAION数据集的元数据(图片URL、文本描述、相似度得分等)就采用了Parquet格式。
- CV场景:由于其处理海量元数据的高效性,被LAION这类网络规模数据集所采用。特别适合需要频繁查询或只读取部分列的场景。
标注格式的选择往往权衡了标注信息的复杂度、生态系统的兼容性以及解析和存储的效率。JSON(如COCO)为复杂的多模态标注提供了极大的灵活性。XML(如PASCAL VOC)是目标检测等以边界框为核心的任务的传统标准。TXT(如YOLO)则以其简洁高效服务于特定的模型输入需求。而Parquet格式的兴起,则直接反映了处理网络级别超大规模元数据的需求。
表4: CV中常见的图像与标注文件格式
文件类型 | CV中主要用途 | 关键结构特征 | 数据集示例 |
---|---|---|---|
PNG | 分割掩码 | 无损, 支持多通道 (灰度, RGB, RGBA), 位深1-16/通道 | COCO (掩码存储), 许多自定义分割数据集 |
JPEG | 自然图像 | “有损, 通常RGB, 8位/通道” | “ImageNet, COCO (图像主体), 大多数图像分类和检测数据集” |
JSON | 复杂/COCO风格标注 | “键值对, 嵌套结构, 支持多种数据类型” | “COCO (检测, 分割, 关键点, 描述)” |
XML | PASCAL VOC/ImageNet检测标注 | “基于标签, 层级结构” | “PASCAL VOC, ImageNet (检测)” |
TXT | “YOLO标注, LaSOT真值, 简单列表” | “行式, 分隔符分隔 (空格, 逗号)” | “YOLO系列数据集, LaSOT (groundtruth.txt)” |
CSV/TSV | “元数据, 简单标注, 图文对” | “表格结构, 行列式, 表头” | “Conceptual Captions (图文对), 一些自定义检测数据集的简易标注” |
Parquet | 大规模元数据存储 | “列式存储, 高效压缩和查询” | LAION (图文元数据) |
第三部分.数据集的加载使用
数据集的使用可以简单的划分为两种,一种是有封装好的函数接口的,比如coco,imagenet等,一种是需要自定义数据集类来使用的。
3.1 使用封装好的接口加载数据集(以coco为例)
PyTorch的torchvision.datasets模块为许多常用的CV数据集提供了内置的加载类,极大地方便了用户。COCO数据集就是其中一个典型的例子。
你可以把 PyTorch 中的 CocoDetection
看作一个“数据读取器”。你告诉它你的图片在哪里,标注文件(JSON 文件)在哪里,然后它就能帮你一张一张地读取图片和它对应的标注信息。
核心步骤:
-
告诉 PyTorch 图片和标注文件的位置:
IMAGE_DIR
: 存放所有 COCO 图片的文件夹路径 (例如val2017/
或train2017/
)。ANNOTATION_FILE
: 那个.json
标注文件的完整路径 (例如annotations/instances_val2017.json
)。
-
创建一个
CocoDetection
对象:from torchvision.datasets import CocoDetection # 1. 定义你的路径 (你需要修改这些!) IMAGE_DIR = "/path/to/your/coco/val2017" ANNOTATION_FILE = "/path/to/your/coco/annotations/instances_val2017.json" # 2. 创建数据集对象 coco_dataset = CocoDetection(root=IMAGE_DIR, annFile=ANNOTATION_FILE)
执行完上面这行,
coco_dataset
就准备好了。它知道怎么去读取你的数据了。 -
从数据集中取出一张图片和它的标注:
CocoDetection
就像一个列表,你可以用索引 (例如0
,1
,2
…) 来获取里面的数据。# 3. 获取第一个样本 (第0张图片和它的标注) # 如果 coco_dataset 为空,这里会报错,实际使用中需要检查数据集大小 if len(coco_dataset) > 0: image, annotations_for_image = coco_dataset[0] else: print("数据集中没有样本!请检查路径和文件。") exit()
image
: 这通常是一个 PIL (Pillow) 库的图像对象。你可以用它来显示图片,或者之后再转换成 PyTorch Tensor。annotations_for_image
: 这是一个列表。因为一张图片里可能有多个物体,所以这个列表里的每一项都是一个字典,代表图片中一个物体的标注信息。
-
使用获取到的信息:
print(f"成功获取到第一个样本!") print(f"图像类型: {type(image)}") # 应该会显示 <class 'PIL.Image.Image'> print(f"图像尺寸: {image.size}") # 显示 (宽度, 高度) print(f"\n这张图片中检测到了 {len(annotations_for_image)} 个物体。") if annotations_for_image: # 如果至少有一个物体被标注 first_object_annotation = annotations_for_image[0] # 获取第一个物体的标注字典 print("\n第一个物体的标注信息示例:") print(f" - 边界框 (bbox): {first_object_annotation['bbox']}") # [x, y, width, height] print(f" - 类别ID (category_id): {first_object_annotation['category_id']}") # 你还可以获取类别ID对应的类别名称 category_id = first_object_annotation['category_id'] # coco_dataset.coco 是一个 COCO API 对象,可以用它来查询类别信息 category_info = coco_dataset.coco.loadCats(category_id) # 返回一个列表,取第一个元素 category_name = category_info[0]['name'] print(f" - 类别名称: {category_name}") else: print("这张图片中没有标注任何物体。")
完整的示例代码:
```python
import os
from torchvision.datasets import CocoDetection # 导入 COCO 数据集加载器
from PIL import Image # 用于确认图像类型
# --- 用户需要修改的路径 ---
# TODO: 请将下面的路径替换为你本地 COCO 数据集的实际路径
COCO_IMAGE_DIR = "/path/to/your/coco/val2017" # 例如: "/data/coco2017/val2017" 或 "C:/Users/YourName/Desktop/coco/val2017"
COCO_ANNOTATION_FILE = "/path/to/your/coco/annotations/instances_val2017.json" # 例如: "/data/coco2017/annotations/instances_val2017.json" 或 "C:/Users/YourName/Desktop/coco/annotations/instances_val2017.json"
# --- 修改结束 ---
# 检查路径是否存在,给用户提示
if not os.path.exists(COCO_IMAGE_DIR) or not os.path.exists(COCO_ANNOTATION_FILE):
print(f"错误:请确保以下路径正确并存在:")
print(f" 图片文件夹: {COCO_IMAGE_DIR}")
print(f" 标注文件: {COCO_ANNOTATION_FILE}")
print(f"请修改脚本中的 COCO_IMAGE_DIR 和 COCO_ANNOTATION_FILE 变量。")
exit()
print(f"尝试从以下位置加载COCO数据集:")
print(f" 图片文件夹: {COCO_IMAGE_DIR}")
print(f" 标注文件: {COCO_ANNOTATION_FILE}")
try:
# 1. 加载数据集
# root 参数指向图片文件夹
# annFile 参数指向 JSON 标注文件
coco_dataset = CocoDetection(root=COCO_IMAGE_DIR,
annFile=COCO_ANNOTATION_FILE)
print(f"\nCOCO 数据集加载成功!")
print(f"数据集中共有 {len(coco_dataset)} 张图片。\n")
# 2. 获取并使用第一个样本 (图像及其所有标注)
if len(coco_dataset) > 0:
# image 是一个 PIL.Image 对象
# target_annotations_list 是一个包含该图像中所有物体标注的列表,每个标注是一个字典
image, target_annotations_list = coco_dataset[0] # 获取第一个样本
print(f"--- 第一个样本信息 ---")
print(f"图像类型: {type(image)}")
if isinstance(image, Image.Image): # 确保 image 是 PIL Image 对象
print(f"图像模式: {image.mode}, 图像尺寸: {image.size}")
print(f"该图像中标注的物体数量: {len(target_annotations_list)}")
# 打印第一个物体的标注(如果存在)
if target_annotations_list:
first_object_ann = target_annotations_list[0] # 获取列表中的第一个标注字典
print(f"\n 第一个物体的标注详情:")
# 使用 .get() 方法访问字典键,以避免因键不存在而引发 KeyError
print(f" 边界框 (bbox [x,y,width,height]): {first_object_ann.get('bbox')}")
print(f" 类别ID (category_id): {first_object_ann.get('category_id')}")
# 使用 coco_dataset.coco 对象来获取类别名称
# coco_dataset.coco 是一个 pycocotools.coco.COCO 类的实例
category_id = first_object_ann.get('category_id')
if category_id is not None: # 确保 category_id 存在
# coco_dataset.coco.loadCats(category_id) 返回一个包含单个字典的列表
category_info_list = coco_dataset.coco.loadCats(category_id)
if category_info_list: # 确保返回的列表不为空
category_name = category_info_list[0]['name']
print(f" 类别名称: {category_name}")
else:
print(f" 警告: 无法找到类别ID {category_id} 的名称。")
else:
print(f" 警告: 第一个标注中没有 'category_id'。")
else:
print(" 该图像中没有标注的物体。")
else:
print("数据集中没有样本。")
except FileNotFoundError:
print(f"\n错误: 文件未找到。请再次检查 COCO_IMAGE_DIR 和 COCO_ANNOTATION_FILE 的路径。")
print(f" 图片文件夹应指向包含 .jpg 文件的目录,例如 'val2017'。")
print(f" 标注文件应指向 .json 文件,例如 'instances_val2017.json'。")
except ImportError:
print(f"\n错误: 看起来 'pycocotools' 可能没有正确安装。")
print(f"请尝试使用 'pip install pycocotools' 或 'conda install -c conda-forge pycocotools' 进行安装。")
except Exception as e:
print(f"\n加载或处理COCO数据集时发生未知错误: {e}")
print("请检查路径是否正确,以及所有依赖库 ('torch', 'torchvision', 'Pillow', 'pycocotools') 是否已正确安装。")
3.2 自定义图像数据集
场景描述
假设你的图像数据存储结构如下:
/path/to/your/custom_data/ <-- 这是你的数据集根目录 (root_dir)
├── cats/ <-- 类别 “cats” 的子文件夹
│ ├── cat_image_001.jpg
│ ├── cat_image_002.png
│ └── …
├── dogs/ <-- 类别 “dogs” 的子文件夹
│ ├── dog_image_001.jpg
│ ├── dog_image_002.jpeg
│ └── …
└── birds/ <-- 类别 “birds” 的子文件夹
├── bird_image_001.jpg
└── …
我们的目标是创建一个 PyTorch Dataset
类,它能够:
- 遍历这些文件夹,找到所有图像的路径。
- 根据图像所在的子文件夹名称(例如 “cats”, “dogs”, “birds”),为每张图像分配一个整数类别标签。
- 在被调用时(例如,通过索引),加载一张指定的图像并返回图像数据(通常是 Tensor)及其对应的整数标签。
核心步骤与代码结构 (torch.utils.data.Dataset
)
在 PyTorch 中,自定义数据集通常需要继承 torch.utils.data.Dataset
类,并重写以下三个核心方法:
-
__init__(self, root_dir, transform=None)
(构造函数/初始化方法):- 目的: 执行一次性的设置操作。通常在这里完成扫描数据、收集所有图像的文件路径、创建类别到整数标签的映射等工作。
root_dir
(字符串): 数据集的根目录路径,例如"/path/to/your/custom_data/"
。transform
(可调用对象, 可选): 一个函数/转换组合,用于对加载的图像进行预处理(例如,调整大小、裁剪、转换为 Tensor、归一化等)。如果不需要转换,可以为None
。
-
__len__(self)
(长度方法):- 目的: 返回数据集中样本的总数量。PyTorch 的
DataLoader
会使用这个方法来确定一个 epoch 中有多少数据,以及如何有效地进行批处理和打乱。
- 目的: 返回数据集中样本的总数量。PyTorch 的
-
__getitem__(self, idx)
(获取单个样本方法):- 目的: 根据给定的索引
idx
,加载并返回一个数据样本(通常是图像和其对应的标签)。这是数据加载的核心,DataLoader
在构建数据批次时会反复调用这个方法。
- 目的: 根据给定的索引
完整示例代码
import os # 导入os模块,用于处理文件和目录路径
from PIL import Image # 导入Pillow库中的Image模块,用于加载和处理图像
import torch # 导入PyTorch库
from torch.utils.data import Dataset # 从PyTorch的工具包中导入Dataset基类
from torchvision import transforms # 从torchvision库中导入transforms模块,用于图像预处理 (可选)
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
自定义图像数据集类的构造函数。
参数:
root_dir (string): 包含所有类别子文件夹的数据集根目录。
例如: "/path/to/your/custom_data/"
transform (callable, optional): 一个可选的转换函数/组合,
用于对加载的图像进行预处理。
默认为 None,表示不进行额外转换。
"""
self.root_dir = root_dir # 数据集根目录路径
self.transform = transform # 图像预处理转换
# 获取根目录下所有子文件夹的名称作为类别名,并进行排序以保证一致性
# os.listdir(root_dir) 返回指定路径下的文件和文件夹列表
# sorted() 对列表进行排序
# 我们只把文件夹名当作类别
self.classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
# 创建类别名到整数标签的映射字典 (例如: {'cats': 0, 'dogs': 1, 'birds': 2})
self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
# 创建整数标签到类别名的映射字典 (例如: {0: 'cats', 1: 'dogs', 2: 'birds'})
self.idx_to_class = {i: cls_name for cls_name, i in self.class_to_idx.items()}
self.image_paths = [] # 用于存储数据集中所有图像的完整路径
self.labels = [] # 用于存储每张图像对应的整数标签
# 遍历根目录下的每个类别子文件夹
for class_name in self.classes:
class_path = os.path.join(root_dir, class_name) # 构建当前类别子文件夹的完整路径
# 遍历当前类别文件夹下的所有文件
for img_name in os.listdir(class_path):
# 简单的文件类型检查,确保只处理常见的图像文件格式
# 可以根据需要扩展支持的文件类型
if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')):
self.image_paths.append(os.path.join(class_path, img_name)) # 将图像路径添加到列表
self.labels.append(self.class_to_idx[class_name]) # 将对应的整数标签添加到列表
# 打印初始化信息
print(f"数据集初始化完成。")
if not self.classes:
print(f"警告: 在 '{root_dir}' 中没有找到任何类别子文件夹。")
else:
print(f"在 '{root_dir}' 中找到 {len(self.classes)} 个类别: {self.classes}")
print(f"类别到索引的映射: {self.class_to_idx}")
if not self.image_paths:
print(f"警告: 在 '{root_dir}' 及其子文件夹中没有找到任何图像文件。")
else:
print(f"总共找到 {len(self.image_paths)} 张图像。")
def __len__(self):
"""
返回数据集中图像的总数量。
这是 Dataset 类必须重写的方法。
"""
return len(self.image_paths) # 返回图像路径列表的长度
def __getitem__(self, idx):
"""
根据给定的索引加载并返回一个数据样本 (图像和其对应的标签)。
这是 Dataset 类必须重写的方法。
参数:
idx (int): 要获取的样本在数据集中的索引。
返回:
tuple: (image, label)
其中 image 是经过转换后的图像 (通常是 PyTorch Tensor),
label 是图像对应的整数标签 (int)。
如果图像加载失败,则 image 可能为 None。
"""
# 1. 根据索引获取图像路径和原始标签
img_path = self.image_paths[idx] # 获取第idx个图像的路径
label = self.labels[idx] # 获取第idx个图像的标签
# 2. 加载图像
# 使用 PIL (Pillow) 库的 Image.open() 方法打开图像文件
# .convert('RGB') 确保图像是3通道RGB格式。
# 这样做可以统一图像格式,即使原始图像是灰度图、带alpha通道的PNG或其他格式。
try:
image = Image.open(img_path).convert('RGB')
except FileNotFoundError:
print(f"错误: 图像文件未在路径 {img_path} 找到。将跳过此图像。")
# 返回 None 会导致 DataLoader 在默认 collate_fn 中出错,
# 更好的处理方式是在 __init__ 中过滤掉不存在的文件,或者使用自定义的 collate_fn。
# 为了简单起见,这里仅打印错误。实际应用中需要更稳健的处理。
return None, torch.tensor(label) # 或者返回一个占位图像
except Exception as e:
print(f"错误: 加载图像 {img_path} 时发生错误: {e}。将跳过此图像。")
return None, torch.tensor(label) # 或者返回一个占位图像
# 3. 应用转换 (如果定义了 transform)
# transform 通常包含一系列图像预处理步骤,如调整大小、裁剪、转换为Tensor、归一化等。
if self.transform:
image = self.transform(image)
return image, label # 返回处理后的图像和其标签
# --- 如何使用这个自定义数据集 (示例代码) ---
if __name__ == '__main__':
# 1. 创建一个假设的数据集目录结构 (仅用于测试目的)
# 在你的项目目录下创建一个名为 "my_custom_images_example" 的文件夹,
# 然后在其中创建 "cats", "dogs", "birds" 子文件夹,并放入一些虚拟图片。
# 实际使用时,你的数据应该已经按照这种结构存在。
# TODO: 将此路径替换为你实际的数据集根目录
# 例如,如果你在当前运行脚本的目录下创建了 my_custom_images_example/cats 等文件夹
# 并且在里面放了图片,那么路径可以是 "./my_custom_images_example"
dataset_root_directory = "./my_custom_images_example"
# 为示例创建虚拟的目录和文件 (如果它们不存在)
# 确保示例代码可以独立运行,即使没有预先准备数据
print(f"检查或创建示例目录: {dataset_root_directory}")
os.makedirs(dataset_root_directory, exist_ok=True) # exist_ok=True 避免目录已存在时报错
example_classes = ["cats", "dogs", "birds"]
image_created_count = 0
for cls_name in example_classes:
class_dir = os.path.join(dataset_root_directory, cls_name)
os.makedirs(class_dir, exist_ok=True)
# 创建一些虚拟的图片文件 (在实际应用中你会用真实的图片)
# 为了避免每次运行时都创建,可以简单检查一下文件是否存在
if not os.listdir(class_dir): # 如果文件夹为空
try:
if cls_name == "cats":
Image.new('RGB', (60, 30), color = 'red').save(os.path.join(class_dir, "dummy_cat1.jpg"))
Image.new('RGB', (50, 40), color = 'green').save(os.path.join(class_dir, "dummy_cat2.png"))
image_created_count += 2
elif cls_name == "dogs":
Image.new('RGB', (60, 30), color = 'blue').save(os.path.join(class_dir, "dummy_dog1.jpeg"))
image_created_count += 1
elif cls_name == "birds":
Image.new('RGB', (70, 70), color = 'yellow').save(os.path.join(class_dir, "dummy_bird1.jpg"))
image_created_count += 1
except Exception as e:
print(f"警告: 创建虚拟图片失败 (可能缺少 Pillow 库,请尝试 `pip install Pillow`): {e}")
if image_created_count > 0:
print(f"提示: 已在 {dataset_root_directory} 中创建了 {image_created_count} 个虚拟图片文件。")
else:
print(f"提示: 示例目录 {dataset_root_directory} 中已存在图片或未创建新图片。")
# 2. 定义一个简单的图像转换 (可选,但对于模型训练通常是必要的)
# 这里我们将图像:
# - 调整到固定大小 (128x128像素)
# - 转换为 PyTorch Tensor
# - ToTensor() 还会将像素值从 [0, 255] (PIL Image) 缩放到 [0.0, 1.0] (Tensor)
simple_transform = transforms.Compose([
transforms.Resize((128, 128)), # 将图像调整到 128x128 像素
transforms.ToTensor() # 将 PIL Image 或 numpy.ndarray 转换为 FloatTensor,
# 并将像素值从 [0, 255] 缩放到 [0.0, 1.0]
])
print(f"\n尝试从以下路径加载自定义数据集: {dataset_root_directory}")
# 3. 实例化自定义数据集
custom_dataset = None # 初始化为 None
try:
custom_dataset = CustomImageDataset(root_dir=dataset_root_directory, transform=simple_transform)
except Exception as e:
print(f"错误: 实例化 CustomImageDataset 失败: {e}")
print(f"请确保路径 '{dataset_root_directory}' 正确,并且其下有类别子文件夹以及有效的图像文件。")
# custom_dataset 保持为 None
if custom_dataset and len(custom_dataset) > 0:
# 4. 验证数据集是否成功加载并包含数据
print(f"\n自定义数据集加载成功!")
print(f"数据集中共有 {len(custom_dataset)} 个样本。") # 调用 __len__()
# 获取并打印第一个样本的信息
# 这里会调用 custom_dataset.__getitem__(0)
first_image, first_label = custom_dataset[0]
print(f"\n--- 第一个样本信息 ---")
if first_image is not None: # 检查图像是否成功加载
print(f"图像类型: {type(first_image)}") # 如果应用了 ToTensor,这里应该是 torch.Tensor
if isinstance(first_image, torch.Tensor):
print(f"图像形状 (通道数, 高度, 宽度): {first_image.shape}") # C, H, W
print(f"标签 (整数): {first_label}")
# 确保 first_label 是有效的键
if first_label in custom_dataset.idx_to_class:
print(f"标签对应的类别名称: {custom_dataset.idx_to_class[first_label]}")
else:
print(f"错误: 标签 {first_label} 在 idx_to_class 映射中未找到。")
else:
print("第一个样本未能成功加载 (图像为 None)。对应的标签是: {first_label}")
# (可选) 获取并打印第二个样本的信息 (如果存在且数据集长度大于1)
if len(custom_dataset) > 1:
second_image, second_label = custom_dataset[1] # 调用 __getitem__(1)
print(f"\n--- 第二个样本信息 (如果存在) ---")
if second_image is not None: # 检查图像是否成功加载
if isinstance(second_image, torch.Tensor):
print(f"图像形状: {second_image.shape}")
else:
print(f"图像类型: {type(second_image)}")
print(f"标签: {second_label} (类别: {custom_dataset.idx_to_class.get(second_label, '未知标签')})")
else:
print(f"第二个样本未能成功加载 (图像为 None)。对应的标签是: {second_label}")
# (可选) 展示如何与 PyTorch DataLoader 一起使用
# DataLoader 提供了数据批处理、打乱、多进程加载等功能,对于高效训练非常重要。
from torch.utils.data import DataLoader
# 定义一个 collate_fn 来处理可能的 None 值 (如果 __getitem__ 返回 None)
def collate_fn_skip_none(batch):
# 过滤掉那些图像为 None 的样本
batch = [item for item in batch if item[0] is not None]
if not batch: # 如果过滤后批次为空
return torch.empty(0), torch.empty(0) # 返回空张量或根据需要处理
return torch.utils.data.dataloader.default_collate(batch)
try:
# 尝试创建一个 DataLoader
# batch_size=2 表示每个批次加载2个样本
# shuffle=True 表示在每个 epoch 开始时打乱数据顺序
# num_workers > 0 可以开启多进程加载数据,提高效率 (在Windows上num_workers>0可能需要将代码放在 __main__ 保护中)
# 使用自定义的 collate_fn 来处理潜在的 None 值
dataloader = DataLoader(custom_dataset, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn_skip_none)
print("\n--- DataLoader 迭代一个批次示例 ---")
# 从 dataloader 中获取一个批次的数据
# images_batch 会是一个形状像 (batch_size, channels, height, width) 的 Tensor
# labels_batch 会是一个包含 batch_size 个标签的 Tensor
batch_count = 0
for images_batch, labels_batch in dataloader:
if images_batch.numel() == 0: # 检查是否是空批次
print("获取到一个空批次,可能所有样本都加载失败。")
continue
print(f"批次图像形状: {images_batch.shape}")
print(f"批次标签: {labels_batch}")
# 获取批次中第一个标签对应的类别名
if len(labels_batch) > 0:
first_label_in_batch = labels_batch[0].item() # .item() 从0维Tensor中获取Python数值
print(f"批次中第一个标签的类别名: {custom_dataset.idx_to_class.get(first_label_in_batch, '未知标签')}")
batch_count += 1
if batch_count >=1 : # 只演示一个或两个有效批次
break
if batch_count == 0:
print("未能从 DataLoader 中获取任何有效批次。")
except Exception as e:
print(f"错误: 创建或使用 DataLoader 时发生错误: {e}")
elif custom_dataset and len(custom_dataset) == 0:
print("\n自定义数据集已成功实例化,但其中没有找到任何图像样本。")
print(f"请检查 '{dataset_root_directory}' 目录下的子文件夹是否包含有效的图像文件 (例如 .png, .jpg, .jpeg 等)。")
print(f"并且确保子文件夹的名称被正确识别为类别。")
else:
# custom_dataset 为 None,表示在 __init__ 阶段就出错了或未成功创建
print("\n自定义数据集未能成功实例化或在初始化过程中发生错误。请检查之前的错误信息。")