十九、Pytorch中的数据加载

1. Pytorch中DataSet的使用方法

1.1 DataSet加载数据的方法

  • DataSet是Pytorch中用来表示数据集的一个抽象类,在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够快速地实现对数据的加载**.**

    __len__:返回数据集大小; __getitem__:可以通过下标方式获取数据

1.2 DataSet类的源码

在这里插入图片描述

1.3 DataLoader使用方法

  • 定义dataset实例
  • 设置读取数据batch的大小,常用128,256等等
  • 设置shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据

1.4 数据集介绍

  • 数据集:setiment.test.data,情感分析二分类数据,数据包含两列,文本和标签.
  • 地址:https://github.com/bojone/bert4keras/tree/master/examples/datasets.
  • 数据集格式如下图所示:

在这里插入图片描述

1.5 代码

  • 步骤一:导入工具库
from torch.utils.data import Dataset, DataLoader
import pandas as pd
  • 步骤二:定义数据读取类
class SentimentDataset(Dataset):
    # 初始化
    def __init__(self, path_to_file):
        self.dataset = pd.read_csv(path_to_file, sep="\t", names=["text", "label"])

    # 返回数据的长度
    def __len__(self):
        return len(self.dataset)

    # 根据编号返回数据
    def __getitem__(self, idx):
        text = self.dataset.loc[idx, "text"]    # 文本
        label = self.dataset.loc[idx, "label"]  # 标签
        sample = {"text": text, "label": label} # 数据样本
        return sample
  • 步骤三:定义主函数
if __name__ == "__main__":
    sentiment_dataset = SentimentDataset("sentiment.test.data")
    print(sentiment_dataset.__getitem__(0)) # 查看第一条数据
  • 步骤四:使用DataLoader批量读取数据
count = 0
for idx, batch_samples in enumerate(sentiment_dataloader):
    text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
    print(idx,text_batchs)
    count += 1
    if count == 3:
        break
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值