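# coding=utf-8
"""
author:lei
function: wrap the SMSSpamCollection text file in a PyTorch Dataset and
demonstrate batching it with a DataLoader.
"""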
import torch
from torch.utils.data import Dataset, DataLoader
import math
# Default location of the SMS data file, relative to the working directory.
data_path = "./data/SMSSpamCollection"
# Dataset class (完成数据集类: "implement the dataset class")
class MyDataset(Dataset):
    """Map-style dataset over a line-oriented SMS text file.

    Every non-blank line of the file is one sample, returned by
    ``__getitem__`` as a ``(label, content)`` pair of strings: the label is
    the first four characters of the stripped line and the content is the
    remainder, both whitespace-stripped. This assumes each line starts with
    a label of at most four characters (presumably ``ham``/``spam`` followed
    by a separator in SMSSpamCollection -- TODO confirm against the file).
    """

    def __init__(self, path=None):
        """Read the data file into memory.

        Args:
            path: Path of the text file to load. When omitted, the
                module-level ``data_path`` is used, so ``MyDataset()``
                behaves as before.
        """
        if path is None:
            path = data_path
        # ``with`` guarantees the file handle is closed once it is read.
        with open(path, encoding="utf8") as f:
            # Skip blank lines so they do not become empty ("", "") samples.
            self.lines = [line for line in f if line.strip()]

    def __getitem__(self, index):
        """Return the ``(label, content)`` pair for the line at ``index``."""
        cur_line = self.lines[index].strip()
        # Fixed-width split: the label is the first 4 characters. strip()
        # removes any separator whitespace left on either side of the cut.
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        """Return the total number of samples (non-blank lines)."""
        return len(self.lines)
# Module-level instances, kept outside the ``__main__`` guard so other
# modules can import ``my_dataset`` / ``data_loader``. Note that this reads
# the data file at import time.
my_dataset = MyDataset()
data_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True)

if __name__ == "__main__":
    print(data_loader)  # the DataLoader is an iterable object
    # Each iteration yields one batch of ``batch_size`` samples, combined
    # by DataLoader's default collate function.
    for batch in data_loader:
        print(batch)
    print(len(my_dataset))  # number of samples (5574 for the full file)
    print(len(data_loader))  # number of batches (2787 with batch_size=2)
    # With drop_last=False, len(data_loader) is the ceiling of
    # len(dataset) / batch_size; this line reproduces that count.
    print(math.ceil(len(my_dataset) / data_loader.batch_size))