数据
数据收集–>img,label
数据划分–>train,valid,test(详细见:https://blog.csdn.net/wyyyyyyfff/article/details/104381429)
数据读取–>dataloader–>sampler(index生成索引,样本序号),dataset(根据索引读取img,label)
数据预处理–>transforms
DataLoader
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。通过使用DataLoader,我们可以方便地对数据进行相关操作,比如我们可以很方便地设置batch_size,对于每一个epoch是否随机打乱数据,是否使用多线程等等。
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
num_works=0,
drop_last=False)
功能:构建可迭代的数据装载器
dataset:Dataset类,决定数据从哪读取以及如何读取
batch_size:批大小
shuffle:每个epoch是否乱序
num_works:是否多进程读取数据
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
关于DataLoader,需要先了解batchsize和epoch等输入数据的相关概念,以及python中类的基本知识比如继承和函数复写。
epoch:所有训练样本都已经输入到模型中,称为一个epoch
iteration:一批样本输入到模型中,称为一个iteration
batchsize:批大小,决定一个epoch有多少iteration
样本总数:87, batchsize:8
drop_last=True–>1 epoch=10 iteration
drop_last=False–>1 epoch=11 iteration
DataLoader的基本使用流程
1.首先会将原始数据加载到DataLoader中去,如果需要shuffle的话,会对数据进行随机打乱操作,这样能够输入顺序对于数据的影响。
2.再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。
Tips: 通过使用迭代器能够有效地降低内存的损耗,会在需要使用的时候才将数据加载到内存中去。
Dataset 解决数据从哪里读取以及如何读取
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写_getitem_()
getitem:接收一个索引,返回一个样本
使用Dataset来创建自己的数据类:
- 继承torch.utils.data.Dataset这个类
- 复写__getitem__ 和 __ len__ 这两个方法
- 如下图接收一个index,返回样本以及标签(如何读取样本,用户编写getitem)
torch.utils.data.Dataset
class Dataset(object):
def _getitem_(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def _add_(self, other):
return ConcatDataset([self, other])
例子:
class MyDataset(Dataset):
""" my dataset."""
# Initialize your data, download, etc.
def __init__(self):
# 读取csv文件中的数据
xy = np.loadtxt('.csv', delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
# 除去最后一列为数据位,存在x_data中
self.x_data = torch.from_numpy(xy[:, 0:-1])
# 最后一列为标签为,存在y_data中
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
# 根据索引返回数据和对应的标签
return self.x_data[index], self.y_data[index]
def __len__(self):
# 返回文件数据的数目
return self.len
数据读取
1.读哪些数据 sampler输出的index
2.从哪读数据 Dataset中的data_dir
3.怎么读数据 Dataset中的getitem
os.path.join(从哪里读数据,数据路径)
import os
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
print(split_dir)
print(train_dir)
print(valid_dir)
…\data\rmb_split
…\data\rmb_split\train
…\data\rmb_split\valid
import shutil
shutil – Utility functions for copying and archiving files and directory trees.
(用于复制和存档文件和目录树的实用功能。)
详细见:https://blog.csdn.net/wyyyyyyfff/article/details/104381429
from PIL import Image
详细:https://www.cnblogs.com/lyrichu/p/9124504.html
https://blog.csdn.net/Li_qf/article/details/84925027
https://blog.csdn.net/zhangziju/article/details/79123275?utm_source=distribute.pc_relevant.none-task