1. Pytorch中DataSet的使用方法
1.1 DataSet加载数据的方法
-
DataSet是Pytorch中用来表示数据集的一个抽象类,在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够快速地实现对数据的加载**.**
__len__
:返回数据集大小;__getitem__
:可以通过下标方式获取数据
1.2 DataSet类的源码
1.3 DataLoader使用方法
- 定义dataset实例
- 设置读取数据batch的大小,常用128,256等等
- 设置shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
1.4 数据集介绍
- 数据集:setiment.test.data,情感分析二分类数据,数据包含两列,文本和标签.
- 地址:https://github.com/bojone/bert4keras/tree/master/examples/datasets.
- 数据集格式如下图所示:
1.5 代码
- 步骤一:导入工具库
from torch.utils.data import Dataset, DataLoader
import pandas as pd
- 步骤二:定义数据读取类
class SentimentDataset(Dataset):
# 初始化
def __init__(self, path_to_file):
self.dataset = pd.read_csv(path_to_file, sep="\t", names=["text", "label"])
# 返回数据的长度
def __len__(self):
return len(self.dataset)
# 根据编号返回数据
def __getitem__(self, idx):
text = self.dataset.loc[idx, "text"] # 文本
label = self.dataset.loc[idx, "label"] # 标签
sample = {"text": text, "label": label} # 数据样本
return sample
- 步骤三:定义主函数
if __name__ == "__main__":
sentiment_dataset = SentimentDataset("sentiment.test.data")
print(sentiment_dataset.__getitem__(0)) # 查看第一条数据
- 步骤四:使用DataLoader批量读取数据
count = 0
for idx, batch_samples in enumerate(sentiment_dataloader):
text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
print(idx,text_batchs)
count += 1
if count == 3:
break