【零基础-3】PaddlePaddle学习Bert

概要

【零基础-1】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121488335【零基础-2】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121523268

Cell 7

# 创建dataloader
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)

Snippet  1

def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):

create_dataloader,创建数据加载器,输入数据集dataset、模式mode(默认为训练集)、batchify_fn(未知,暂时理解成batchify_function,即batch化的函数)、trans_fn(转换样本的函数)。

Snippet 2

if trans_fn:
        dataset = dataset.map(trans_fn)

如果传入了trans_fn,就使用trans_fn将dataset进行一个转换,dataset.map的api文档如下

dataset — PaddleNLP 文档https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.datasets.dataset.html?highlight=dataset.map#paddlenlp.datasets.dataset.MapDataset.map

Snippet 3

shuffle = True if mode == 'train' else False

如果是训练集,就打乱,否则不打乱,这里的语法相当于C、Java的三目运算符

shuffle = mode == 'train' ? true : false 

Snippet 4

if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

如果传入的是训练集,则调用paddle.io.DistributedBatchSampler处理得到batch_sampler,如果传入的不是训练集,则调用paddle.io.BatchSampler处理得到batch_sampler。

这里为什么要得到batch_sampler呢?

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Toblerone_Wind

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值