下载
官网:ImageNet
百度网盘: https://pan.baidu.com/s/1Tmh9-XWUvDwexf-00P5IyQ?pwd=c86x 提取码: c86x
官网需要教育邮箱注册登录才能下载,而且下载速度很慢。建议使用百度网盘链接
下载后会得到下面三个压缩包
- ILSVRC2012_devkit_t12.tar.gz:工具包
- ILSVRC2012_img_train.tar:训练集,140G
- ILSVRC2012_img_val.tar:验证集,6.4G
处理
调用官方的方法进行处理
from torchvision.datasets.imagenet import parse_devkit_archive, parse_train_archive, parse_val_archive
root = "/path/to/folder/of/archives"
parse_devkit_archive(root)
parse_train_archive(root)
parse_val_archive(root)
使用这段代码需要将root路径设为三个压缩包的根目录,运行上面的代码,会自动根据工具包中的信息编织数据。其中train压缩包需要先解压。
tar -xvf /root/autodl-tmp/datasets/ImageNet2012/ILSVRC2012_img_train.tar -C /root/autodl-tmp/datasets/ImageNet2012/
最终得到如下的文件结构:
/imagenet/
|----train
|----n01440764
|----...
|----val
|----n01440764
|----...
使用
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据变换
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载训练集
train_dataset = datasets.ImageNet(root='/path/to/imagenet', split='train', transform=transform)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
# 迭代数据集
for inputs, labels in train_loader:
print(inputs.shape) # 打印输入图像的形状
print(labels.shape) # 打印标签的形状
break # 仅迭代一次,查看数据的输出
常用属性:
classes
- 包含所有类别标签的名称的列表。这个列表的长度是 1000,涵盖了 ImageNet 数据集的 1000 个类别。你可以通过索引访问具体类别的名称。
class_to_idx
- 类别名称到类别索引的映射。它是一个字典,键是类别的名称,值是类别的索引。
imgs
- 返回数据集的所有图像路径和标签。它是一个包含元组的列表,每个元组包含图像的路径和该图像的类别标签。
samples
- 与
imgs
类似,包含了图像文件路径和标签对。该属性通常在数据集的加载时由类自动填充。
主要方法:
__getitem__(index)
-
说明:这个方法是
ImageNet
类的核心方法,它用于获取给定index
的图像和对应的标签。 -
返回:返回一个元组
(image, label)
,其中image
是经过处理后的图像,label
是图像对应的类别标签。
__len__()
-
说明:返回数据集中的图像数量。
-
返回:一个整数,表示数据集中的样本数。
extra_repr()
-
说明:返回数据集的描述信息。该方法会显示数据集的一些信息,例如数据集的根目录、当前使用的分割(
train
或val
)、类别数量等。