# 数据集类的使用
# http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
import torch
from torch.utils.data import Dataset, DataLoader
import math
# Absolute path of the SMS Spam Collection data file (see the dataset URL in
# the file header). MyDataset reads this file with open(); the value is
# machine-specific, so point it at the local copy of the dataset.
data_path = r"D:\下载内容\smsspamcollection\SMSSpamCollection"
# 完成数据集类
class MyDataset(Dataset):
    """Line-oriented dataset over the SMS Spam Collection text file.

    Every line of the file is one sample. Indexing returns a
    ``(label, content)`` pair of strings: the label is cut from the first
    four characters of the line and the content is everything after them.
    """

    def __init__(self, path=None):
        """Load all lines of the data file into memory.

        Args:
            path: File to read. When omitted (or ``None``) the module-level
                ``data_path`` is used, which matches the original behaviour.
        """
        super().__init__()
        source = data_path if path is None else path
        # Use a context manager so the handle is closed as soon as the lines
        # are loaded instead of being left open until garbage collection.
        # errors='ignore' silently drops bytes that are not valid gb18030.
        with open(source, encoding='gb18030', errors='ignore') as handle:
            self.lines = handle.readlines()

    def __getitem__(self, index):
        """Return the ``(label, content)`` pair for the line at ``index``.

        Raises:
            IndexError: If ``index`` is outside the range of loaded lines.
        """
        cur_line = self.lines[index].strip()
        # NOTE(review): the fixed width of 4 presumably relies on labels being
        # "ham" or "spam" followed by a whitespace separator; strip() removes
        # the separator from whichever slice it ends up in -- confirm against
        # the data file.
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        """Return the number of lines read from the file."""
        return len(self.lines)
# Build the dataset and its loader at import time so both stay available as
# module-level names.
my_dataset = MyDataset()
# batch_size=2 groups two samples per batch, shuffle=True reorders the samples
# on each pass, and drop_last=True discards a trailing incomplete batch.
data_loader = DataLoader(
    dataset=my_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
)

if __name__ == '__main__':
    # Quick manual checks, left disabled:
    #   print(my_dataset[1000])
    #   print(len(my_dataset))
    for batch in data_loader:
        print(batch)
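# 数据集类和数据加载器类 (Dataset class and DataLoader class)
# 最新推荐文章于 2024-01-06 15:55:21 发布 (scraped page footer: "latest recommended article published 2024-01-06 15:55:21")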