1. Dataset
在PyTorch中,Dataset是一个基础类,用于定义数据集的结构和数据的加载方式。通过继承Dataset类并实现特定的方法,可以创建自定义的数据集,这些方法通常包括获取数据集的大小、获取单个样本等。
对于任何数据集,如果它提供了从键到数据样本的映射,在对数据集进行加载时,都应该对该类进行子类化。
Dataset 作为一个抽象类,在继承使用该类时,应该实现两个方法,定义自己的数据集。分别是:
- __getitem__(self, index): 根据的索引 index,返回对应的数据;索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。
- __len__: 返回数据集中样本的总数。
2. 自定义Dataset
import torch
from torch.utils.data import Dataset
# 创建自己的 数据集 Dataset
class MyDataset(Dataset):
# 对数据集 进行初始化 ,包括 数据与标签
def __init__(self, data, labels):
self.data = data
self.labels = labels
# 返回数据集的长度
def __len__(self):
return len(self.data)
# 根据索引返回对应的 value 值
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 假设data和labels是已经加载或生成的数据和标签
data = [1, 2, 3, 4, 5] # 示例数据
labels = ['春', '夏', '春', '夏','夏'] # 示例标签
# 初始化数据
dataset = MyDataset(data, labels)
# 根据 索引,返回对应的数据与标签
print(dataset[1]) # (2, '夏')
# 返回数据集的长度
print(len(dataset)) # 5
# 查看数据集中的所有数据
print(dataset.__dict__) # {'data': [1, 2, 3, 4, 5], 'labels': ['春', '夏', '春', '夏', '夏']}
此外,PyTorch还提供了多种内置的数据集类,如MNIST、CIFAR等,这些内置数据集可以直接使用,无需自定义。
读取 数据集练习:
项目的目录格式:
import os
from PIL import Image
class MyDataset1(Dataset):
def __init__(self, root_dir, img_dir, label_dir):
self.root_dir = root_dir
self.img_dir = img_dir
self.label_dir = label_dir
# 由于 linux 与 Windows 的斜杠不同,因此采用该方法 os
self.img_path = os.path.join(self.root_dir,self.img_dir)
self.label_path = os.path.join(self.root_dir, self.label_dir)
# 列出所有的图片名称
self.img_name = os.listdir(self.img_path)
# 列出所有的标签名称
self.label_name = os.listdir(self.label_path)
def __len__(self):
return len(self.img_name)
def __getitem__(self, idx):
# 获取 图片与标签的相对路径
img_relative_path = os.path.join(self.img_path,self.img_name[idx])
lab_relative_path = os.path.join(self.label_path,self.label_name[idx])
with open(lab_relative_path,"r") as f:
labels = f.read()
img = Image.open(img_relative_path)
return img, labels
root_dir = "hymenoptera_data/train"
ants_img_dir = "ants_img"
bees_img_dir = "bees_img"
ants_lab_dir = "ants_lab"
bees_lab_dir = "bees_lab"
ants_dataset = MyDataset1(root_dir,ants_img_dir,ants_lab_dir)
bees_dataset = MyDataset1(root_dir,bees_img_dir,bees_lab_dir)
# 获取所有的数据集
train_dataset = ants_dataset + bees_dataset
# 读取数据集
print(ants_dataset[0])
print(len(ants_dataset))
方式:
1. 首先 通过 init() 函数, 传入对应的数据集的各个文件夹名称;在 init() 函数中 通过 os.path.join() 函数,对文件夹名称进行拼接,获取所需要的 相对地址;最后可以通过 os.listdir() 函数, 获取文件夹中的文件名称列表。
2. 在 getitem() 函数中,通过 index 索引, 找到对应的 文件名称,拼接成该文件的 相对地址,并通过对应的函数进行打开;最后将打开的结果 通过 return 返回。
3. torchversion
介绍
torchvision是pytorch的一个图形库,专注于计算机视觉任务,包括预训练模型、数据加载和预处理,模型构建和评估等。
torchvision的设计初衷是为了简化使用PyTorch进行计算机视觉任务的开发和研究,减少重复性代码的编写,提高开发效率。
使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Compose 对图像进行处理
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform,download=True)
print(train_set[0]) #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2B0BD213170>, 6)
print(train_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img, target = train_set[0]
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x18BB86AF200>
print(target) # 6 查看 该图片的 标签 是frog
print(train_set.classes[target]) # frog
write = SummaryWriter("log_1")
for i in range(10):
img, target = train_set[i]
write.add_image("train_set",img,i)
write.close()
torchvision.datasets.CIFAR10是PyTorch中提供的一个数据集类,用于加载CIFAR-10数据集。CIFAR-10数据集是一个常用的图像分类数据集,包含10个类别的60000张32x32的彩色图像,其中50000张用于训练,10000张用于测试。这个数据集常用于训练图像分类模型。
参数:
- root:数据集的存储路径。如果设置为下载模式(download=True),则会在该路径下保存数据集;如果数据集已经存在,则可以直接从该路径加载。
- train:一个布尔值,表示是否加载训练数据集。如果为True,则加载训练数据;如果为False,则加载测试数据。
- transform:对数据集中的图像进行的一些操作,如归一化、随机裁剪、数据增强等。这是一个可选参数,可以根据需要进行设置。
- target_transform:对数据集中的标签进行的一些操作。这也是一个可选参数,用于对标签进行预处理。
- download:一个布尔值,表示是否需要从Internet下载数据集到指定的root目录。默认为False,即如果数据集已经存在,则不会重新下载。
class CIFAR10(VisionDataset) --> class VisionDataset(data.Dataset) :
可见 CIFAR10 最终继承自 Datase;
其内部也实现了 dataset 的 方法:
- __getitem__ 方法
- __len__ 方法
3. DataLoader
主要用于简化深度学习中的数据加载和预处理过程。它是PyTorch框架中的一个类,能够将数据集(通常是一个Dataset对象)打包成一个可迭代的对象,方便在训练过程中逐批次读取数据。DataLoader提供了多种功能,包括数据的随机打乱、并行加载、多线程加载等,以优化数据处理的效率和灵活性。此外,DataLoader还支持数据清洗和灵活的数据转换,如字段的映射、计算、替换等操作,进一步降低了数据处理的难度。处理完数据后,DataLoader还支持将数据导出为多种格式,如Excel、CSV、数据库等,满足不同用户的需求。
- dataset:这是加载的数据集对象,数据类型为dataset。它指定了DataLoader将要处理的数据来源。
- batch_size:定义了每个batch的大小,即数据被切分的份数。数据类型为int。这个参数决定了每次从数据集中取出的数据量。
- shuffle:决定是否打乱数据的顺序,默认设置为False。数据类型为bool。启用shuffle可以在每个epoch开始时打乱数据的顺序,有助于模型的泛化能力。
- drop_last:当数据集的大小不能被batch_size整除时,决定是否丢弃最后一个不完整的batch。数据类型为bool。如果设置为True,最后一个不完整的batch将被丢弃;如果设置为False,则会保留所有数据,包括不完整的batch。
- num_workers:表示在数据加载时使用的子进程数。在Windows上设置为0,而在Linux上可以设置大于0的值。数据类型为int。这个参数影响了数据加载的速度,通过增加工作进程数可以并行加载数据,提高效率。
- collate_fn:这是一个callable对象,用于将batch_size样本整理成一个batch样本,便于批量处理训练。它定义了如何将多个样本组合成一个batch的规则。
- sampler:这是一个“采样器”,表示从样本中如何取样,提供整个数据集的随机访问的索引列表。数据类型为Sampler。它允许更灵活地控制数据的取样方式。
- Pin_memory:内存寄存标志,默认为False。在数据返回前,决定是否将数据复制到内存中(CPU/GPU)。数据类型为bool。这可以影响数据加载的速度和效率。
- Timeout:设置数据读取的超时时间,超过该值还未读取到数据就会报错。数据类型为numeric。这保证了DataLoader在等待数据时的超时机制。
- Worker_init_fn:这是一个callable对象,用于初始化工作进程。它允许在每个工作进程开始时执行特定的初始化操作。
示例代码:
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
# 测试集
train_data = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_data,batch_size=4,shuffle=True,num_workers=0 ,drop_last=False)
# 测试数据集
img , target = train_data[0]
print(img.shape)
print(target)
write = SummaryWriter("log_2")
write.add_image("Image2",img, dataformats="CHW")
step=0
for data in train_loader:
imgs, target = data
print(img.shape)
# print(target)
write.add_image("Image4",imgs,step, dataformats="NCHW")
step = step +1
write.close()
dataformats 中的 NCHW的数据排列方式: