import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class MyDataset(Dataset):
    """Map-style dataset backed by an in-memory sequence of samples.

    Minimal template: pass any indexable, sized sequence (e.g. a list or a
    tensor) as ``data``. Constructed with no argument, the dataset is empty.
    """

    def __init__(self, data=None):
        """Store the samples this dataset serves.

        Args:
            data: Optional indexable, sized sequence of samples. ``None``
                (the default) yields an empty dataset.
        """
        # Fresh list per instance instead of a shared mutable default.
        self.data = [] if data is None else data

    def __getitem__(self, index):
        """Return the sample at ``index``."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        # len() requires an int here; returning None (as a bare `pass` body
        # does) makes len() -- and anything built on it -- raise TypeError.
        return len(self.data)
dataset = MyDataset()
# NOTE(review): MyDataset() is constructed without any samples here; with
# shuffle=True, DataLoader rejects an empty dataset, so supply real data
# before running this.
#
# shuffle=True reshuffles the sample order at every epoch.
# num_workers is the number of worker *processes* (not threads) that load
# batches in the background; 0 would load in the main process.
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)