BERT数据集迭代器

 # 数据集迭代器
    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.dataset[self.index * self.batch_size : len(self.dataset)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index > self.n_batches:
            self.index = 0
            raise StopIteration

        else:
            batches = self.dataset[self.index * self.batch_size : (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

这段代码是一个迭代器的__next__()方法实现。在这个方法中,首先判断是否存在剩余的数据(self.residue)且当前批次的索引等于总批次数(self.index == self.n_batches)。如果满足条件,则表示当前是最后一个批次,且存在剩余数据无法构成完整的批次。此时,执行与之前相同的操作,将剩余数据转换为张量格式,并作为最后一个批次返回。

接下来,判断当前批次的索引是否大于总批次数(self.index > self.n_batches)。如果满足条件,则表示已经迭代完所有批次,需要抛出StopIteration异常来终止迭代器。

最后,如果以上两个条件都不满足,说明仍然有完整的批次可以迭代。在这种情况下,获取当前批次的数据(batches = self.dataset[self.index * self.batch_size : (self.index + 1) * self.batch_size]),将索引加1(self.index += 1),并将获取的数据转换为张量格式(batches = self._to_tensor(batches))。最后,返回这个批次的数据。

总之,这段代码实现了对数据集进行批次迭代的功能。当存在剩余数据时,将剩余数据作为最后一个批次返回;当迭代完所有批次时,抛出StopIteration异常;否则,返回下一个批次的数据。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值