PyTorch数据准备与预处理
.py源文件的结构
- 数据准备与预处理: dataset.py
- 模型:model.py
- 训练规则:train.py
- 测试(benchmark + predict):test.py
dataset.py的功能
统一将图像返回成torch能处理的[original_iamges.tensor,label.tensor]
用来加载数据的类torch.utils.data.DataLoader()
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
重点关注四个参数:
batch_size: 批处理数目
shuffle: 是否每个epoch都打乱
workers: 载入数据的线程数
dataset: 是经过变换的自己的数据集(即:一个继承了torch.utils.data.Dataset
类的子类的实例),[original_iamges.tensor,label.tensor]
之类的,定义的“dataset.py”就是产生这个dataset的。然后在train.py中调用。
导入自己的dataset
class UAVDataSet(torch.ultis.data.Dataset)
-
继承了
torch.utils.data.Dataset
这个(抽象)类,我们看看这个类在中文文档中介绍: -
所有其他数据集都应该进行子类化。所有子类应该重载
__len__
和__getitem__
,前者提供了数据集的大小,后者支持整数索引,范围从0
到len(self)
。当然还有个初始化__init__()
。 -
类 = 属性+方法(变量 + 函数),
__init__()
就是定义自己的属性。
数据子类的基础框架
如上述,必须要重载的是__getitem__()
和__len__()
。
__len__()
:len(dataset)
返回数据集的大小。__getitem__()
:实现数据集的下标索引,使用dataset[i]
来得到第i个样本(图像和标记)。
import torch.utils.data as data
import torch
from torchvision import transforms
class MyTrainData(torch.utils.data.Dataset) #子类化
def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
self.root = root
self.train = train
def