pytorch制作数据集不像TensorFlow那么复杂,只需要交单的把数据集加载进来,继承Dataset类和dataloader类
继承Dataset类
在使用时只需要继承该类,并重写__len__()和__getitem()__函数,即可以方便地进行数据集的迭代。
from torch.utils.data import Dataset
class my_data(Dataset):
def __init__(self, image_path, annotation_path, transform=None):
"""初始化,读取数据集"""
pass
def __len__(self):
"""获取数据集的总大小"""
return
def __getitem__(self, id):
"""对于指定的id,读取该数据并返回"""
idx = id
return
继承dataloader类
经过Dataset类封装,已经可以获取每一个样本,但是仍然无法进行批量处理、随机选取等操作,因此还需要torch.utils.data.Dataloader类进
一步进行封装
# 使用Dataloader进一步封装Dataset
dataset = my_data()
dataloader = Dataloader(dataset, batch_size=4,shuffle=True,num_workers=4)