1.说明
pytorch中用dataset来对单个样本进行打包features和labels,得到datasets=(features,labels);用DataLoader来包装Datasets,使得可以每一个批次batchsize打包起来,得到一个批量大小的datasets
torch.utils.data.Dataset
:打包(features_i,labels_i)得到datasettorch.utils.data.DataLoader
:打包dataset_i 得到 DataLoader
2. 从pytorch中得到datasets
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# 训练的datasets
training_data = datasets.FashionMNIST(
# root:表示地址
root="data",
# True:表示是否是训练集
train=True,
download=True,
# transform:表示的是将图片转换张量,这里可以对图片进行相关预处理
transform=ToTensor()
)
# 测试的datasets
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
3. 自定义datasets
自定义dataset需要满足4个条件
- 继承自官方的datasets类
class CustomImageDataset(Dataset):
- 覆写初始化函数
__init__
当实例化Dataset对象时,__init__函数运行一次。我们初始化包含图像、注释文件和两个转换的目录
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
(1)annotations_file
:表示的是labels的csv格式文件名
(2)img_dir
:表示的是features图片目录
(3)transform
:表示的是对图片features进行预处理
(4)target_transform
:表示的是对标签labels进行预处理
- 覆写长度
__len__
函数
__len__函数返回数据集中的样本数量
def __len__(self):
return len(self.img_labels)
- 覆写
__getitem__
函数
作用:根据给定的index序号从datasets中返回特征features和标签
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
img_path
:图片的位置label
:根据idx得到对应的标签labelself.transform
:将对应的图片进行转换self.target_transform
:将标签转换预处理
4. 定义DataLoader
我们定义好一个样本的datasets,我们需要将多个datasets转换成一个批量大小的DataLoader;我们经常需要将一个批次里面的datasets进行打乱处理再训练,所以我们需要定义DataLoader;
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
training_data
: 训练集的datasetstest_data
:测试集的datasetsbatch_size
:批量大小,需要多少个datasetsshuffle
:一个批量大小中的datasets是否需要打乱,为了提高模型的鲁棒性
4.1 DataLoader
在文件dataloader.py中有定义
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
dataset
:传入的datasetsbatch_size
: 定义批量大小,默认为1shuffle
:是否打乱一个批量大小里面的datasetssampler
: 采样batch_sampler
:批量采样num_workers
:为数据使用多少子流程;可以给程序启动多进程处理collate_fn
: 用于如何取样本,可以自己定义如何对样本的取出处理pin_memory
:将数据固定到GPU上drop_last
:如果为真,删除最后一个不完整的batch