序列数据采样方法的封装调用--python代码实现

本篇代码基于本人上一篇《序列数据的采样方法:python代码实现》 

上一篇网址:序列数据的采样方法:python代码实现-CSDN博客

class SeqDataLoader:
    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        #根据设置的标志进行判断采用哪种采样方法
        if use_random_iter:
            self.data_iter_fn = seq_data_iter_random
        else:
            self.data_iter_fn = seq_data_iter_sequential
            
        #调用函数获取语料库,词汇表
        self.corpus, self.vocab = load_corpus_time_machine(max_tokens)
        #初始化赋值
        self.batch_size, self.num_steps = batch_size, num_steps
        
    #定义函数:返回选取的采样方法的输出结果
    def __iter__(self):
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)


# 返回数据迭代器和词汇表
def load_data_time_machine(batch_size, num_steps, user_random_iter=False, max_tokens=10000):
    #调用实例对象SeqDataLoader()
    data_iter = SeqDataLoader(batch_size, num_steps, user_random_iter, max_tokens)
    return data_iter, data_iter.vocab 



#测试上面的封装代码
batch_size, num_steps =2,10
loader, vocab = load_data_time_machine(batch_size,
                                       num_steps)


for i in loader:
    print(i)
    break

 测试输出结果:

(tensor([[ 4, 15,  9,  5,  6,  2,  1, 21, 19,  1],
        [ 1, 17,  4,  8,  1,  4, 12,  7,  6, 18]]), tensor([[15,  9,  5,  6,  2,  1, 21, 19,  1,  9],
        [17,  4,  8,  1,  4, 12,  7,  6, 18,  3]]))
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值