pytorch 提供了一种数据处理的方式,能够生成mini-batch的数据,在训练和测试的时候进行多线程处理,加快准备数据的速度。这个函数工具是
torch.utils.data import Dataset, DataLoader
其中Dataset是我们定义自己的多线程数据处理框架的父类,我们定义的框架要继承这个类
下面简单定义数据准备的框架吧!!!
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self, filepath, transform=None,keys = None, targe