我们在学习Pytorch进行文本处理时,所使用的数据集基本上都为官方提供的处理好的,调用torchtext中的相应函数即可实现对数据的处理。那么当我们需要加载自己的数据集时该怎么办呢,本文将以txt文件为例讲解一下如何加载。
我们的txt文件包含852471行,每一行如图所示为一句话
我们将使用torch.utils.data中包含的相关类,将该文件分割成训练集和验证集,并生成迭代器。
1、导入相关类
import os
from torch.utils.data import Dataset, random_split, DataLoader
torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
torch.utils.data.DataLoader: 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
2、定义我们自己的dataset类
class MyDataset(Dataset):
def __init__(self, instances):
self.instances = instances
//数据集的样本总数
def __len__(self):
return