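本篇代码基于本人上一篇文章《序列数据的采样方法:python代码实现》;下文用到的 seq_data_iter_random、seq_data_iter_sequential 和 load_corpus_time_machine 未在本篇中定义,运行前需先行准备(应来自上一篇,请对照确认)。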
上一篇网址:序列数据的采样方法:python代码实现-CSDN博客
class SeqDataLoader:
    """Iterable that loads a corpus once and yields sampled minibatches.

    The sampling function is chosen at construction time; each call to
    ``iter()`` on the loader invokes that function again on the stored
    corpus, so the loader can be iterated repeatedly.

    Attributes:
        data_iter_fn: The sampling function selected by ``use_random_iter``.
        corpus: First value returned by ``load_corpus_time_machine``.
        vocab: Second value returned by ``load_corpus_time_machine``.
        batch_size: Stored as given; forwarded to the sampling function.
        num_steps: Stored as given; forwarded to the sampling function.
    """

    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        """Select the sampling function and load the corpus and vocabulary.

        Args:
            batch_size: Forwarded to the sampling function on iteration.
            num_steps: Forwarded to the sampling function on iteration.
            use_random_iter: If truthy, use ``seq_data_iter_random``;
                otherwise use ``seq_data_iter_sequential``.
            max_tokens: Passed through to ``load_corpus_time_machine``.
        """
        # Decide which sampling method to use based on the flag.
        if use_random_iter:
            self.data_iter_fn = seq_data_iter_random
        else:
            self.data_iter_fn = seq_data_iter_sequential
        # Fetch the corpus and the vocabulary.
        self.corpus, self.vocab = load_corpus_time_machine(max_tokens)
        # Remember the batching parameters for later iteration.
        self.batch_size, self.num_steps = batch_size, num_steps

    def __iter__(self):
        """Return the result of the selected sampling function.

        The sampling function is called with the stored corpus, batch size
        and number of steps, in that order.
        """
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)
def load_data_time_machine(batch_size, num_steps, user_random_iter=False, max_tokens=10000):
    """Return the data iterator and the vocabulary.

    Args:
        batch_size: Passed to ``SeqDataLoader``.
        num_steps: Passed to ``SeqDataLoader``.
        user_random_iter: Passed to ``SeqDataLoader`` as its
            ``use_random_iter`` flag. The spelling differs from the class
            parameter; it is kept so existing keyword callers still work.
        max_tokens: Passed to ``SeqDataLoader``.

    Returns:
        A tuple ``(data_iter, vocab)`` where ``data_iter`` is the
        ``SeqDataLoader`` instance and ``vocab`` is its ``vocab`` attribute.
    """
    # Build the loader instance; it loads the corpus in its constructor.
    data_iter = SeqDataLoader(batch_size, num_steps, user_random_iter, max_tokens)
    return data_iter, data_iter.vocab
# Try out the wrapper code above.
batch_size, num_steps = 2, 10
loader, vocab = load_data_time_machine(batch_size, num_steps)
for i in loader:
    # Print only the first item produced by the loader, then stop.
    print(i)
    break
测试输出结果:
(tensor([[ 4, 15, 9, 5, 6, 2, 1, 21, 19, 1], [ 1, 17, 4, 8, 1, 4, 12, 7, 6, 18]]), tensor([[15, 9, 5, 6, 2, 1, 21, 19, 1, 9], [17, 4, 8, 1, 4, 12, 7, 6, 18, 3]]))