在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现数据的加载。
源码如下:
class Dataset(Dataset[Tuple[Tensor, ...]]):
def __init__(self, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
我们只需要继承这个基类就可以了,
首先导入我们需要的包
import torch
from torch.utils.data import Dataset #导入Dataset
然后在网上找一个数据集,http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection,在这个网站里面有一个短信数据集,几千条短信,里面包含正常短信和几百条骚扰短信。下载下来放到合适的位置,并且记住位置。这个位置需要自己放到想要放的地方。
#数据的位置,这个位置需要自己放到想要放的地方
data_path=r'C:\Users\Administrator\Desktop\smsspamcollection\SMSSpamCollection1'
继承父类Dataset
#继承父类Dataset
class Mydata(Dataset):
def __init__(self):
self.lines=open(data_path, encoding='utf-8').readlines() #打开数据并读取
def __getitem__(self,index): #获取索引对应位置的一条数据
return self.lines[index].strip()
def __len__(self): #返回数据的数量
return len(self.lines)
对其进行实例化,然后打印输出。
my_datdaset=Mydata() #实例化
print(my_datdaset.lines[2]) #输出第二个数据
OK
上述方法仅仅能够进行数据的读取,如果还需要实现批处理数据,打乱数据,使用多线程加载数据,可以使用Pytorch中的torch.utils.data.DataLoader。
代码如下:
import torch
from torch.utils.data import Dataset #导入Dataset
from torch.utils.data import DataLoader #导入DataLoader
#数据的位置,这个位置需要自己放到想要放的地方
data_path=r'C:\Users\Administrator\Desktop\smsspamcollection\SMSSpamCollection1'
#继承父类Dataset
class Mydata(Dataset):
def __init__(self):
self.lines=open(data_path, encoding='utf-8').readlines() #打开数据并读取
def __getitem__(self,index): #获取索引对应位置的一条数据
return self.lines[index].strip()
def __len__(self): #返回数据的数量
return len(self.lines)
my_datdaset=Mydata() #实例化
#使用DataLoader
data_loader=DataLoader(dataset=my_datdaset,batch_size=2,shuffle=True,num_workers=2)
#print(my_datdaset.lines[2]) #输出第二个数据
#if __name__ == '__main__'的意思是:当.py文件被直接运行时,if __name__ == '__main__'之下的代码块将被运行;当.py文件以模块形式被导入时,
#if __name__ == '__main__'之下的代码块不被运行。
if __name__ == '__main__':
for i in data_loader:
print(i) #随机输出一个batch,一个batch里面包含两个样本
break #只循环一次就退出,不需要打印太多