import gluonnlp as nlp
from gluonnlp.data.dataset import CorpusDataset, SimpleDataset
from mxnet import np, npx
import mxnet as mx
import math
npx.set_np()
# As the code below illustrates, truncated BPTT caps the number of time steps trained at once via seq_len,
# which alleviates the vanishing-gradient problem that plain RNN training over long sequences suffers from.
batch_size = 5
bptt = 6
def wordtoword_splitter(s):
    """Split a string into a list of individual characters."""
    return list(s)
def _slice_pad_length(num_items, length, overlap=0):
"""Calculate the padding length needed for sliced samples in order not to discard data.
计算在不丢弃批次不足的样本时需要补充多少填充项
Parameters
----------
num_items : int
Number of items in dataset before collating.
length : int
The length of each of the samples.
overlap : int, default 0
The extra number of items in current sample that should overlap with the
next sample.
Returns
-------
Length of paddings.
"""
if length <= overlap:
raise ValueError('length needs to be larger than overlap')
    step = length - overlap  # the overlapping tokens are counted only once, so the effective step is length - overlap
    span = num_items - length  # subtract one full slice first; the remaining span must be divisible by step, otherwise pad
residual = span % step
if residual:
return step - residual
else:
return 0
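# Quick sanity checks for _slice_pad_length (illustrative numbers added for this write-up):
# step = length - overlap, and the remainder of (num_items - length) % step decides the padding.
assert _slice_pad_length(9, 5, overlap=1) == 0   # (9 - 5) % 4 == 0, already aligned, no padding
assert _slice_pad_length(10, 5, overlap=1) == 3  # (10 - 5) % 4 == 1, so pad 4 - 1 = 3 items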
def slice_sequence(sequence, length, pad_last=False, pad_val='<pad>', overlap=0):
"""Slice a flat sequence of tokens into sequences tokens, with each
inner sequence's length equal to the specified `length`, taking into account the requested
sequence overlap.
Parameters
----------
sequence : list of object
A flat list of tokens.
length : int
The length of each of the samples.
pad_last : bool, default False
Whether to pad the last sequence when its length doesn't align. If the last sequence's
length doesn't align and ``pad_last`` is False, it will be dropped.
    pad_val : object, default '<pad>'
The padding value to use when the padding of the last sequence is enabled. In general,
the type of ``pad_val`` should be the same as the tokens.
overlap : int, default 0
The extra number of items in current sample that should overlap with the
next sample.
Returns
-------
List of list of tokens, with the length of each inner list equal to `length`.
"""
if length <= overlap:
raise ValueError('length needs to be larger than overlap')
if pad_last:
pad_len = _slice_pad_length(len(sequence), length, overlap)
sequence = sequence + [pad_val] * pad_len
num_samples = (len(sequence) - length) // (length - overlap) + 1
    # each slice starts at i * (length - overlap), so consecutive slices share `overlap` tokens
    # each slice ends at i * (length - overlap) + length, i.e. the start plus one full slice length
return [sequence[i * (length - overlap): ((i + 1) * length - i * overlap)]
for i in range(num_samples)]
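# Illustrative example of slice_sequence (added for this write-up): slicing 0..9 with
# length=4 and overlap=1 gives three chunks, each sharing exactly one token with its neighbour.
assert slice_sequence(list(range(10)), 4, overlap=1) == [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]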
class CorpusBPTTBatchify:
"""Transform the dataset into batches of numericalized samples, in the way
that the recurrent states from last batch connects with the current batch
for each sample.
Each sample is of shape `(seq_len, batch_size)`. When `last_batch='keep'`, the first
dimension of last sample may be shorter than `seq_len`.
Parameters
----------
vocab : gluonnlp.Vocab
The vocabulary to use for numericalizing the dataset. Each token will be mapped to the
index according to the vocabulary.
seq_len : int
The length of each of the samples for truncated back-propagation-through-time (TBPTT).
batch_size : int
The number of samples in each batch.
last_batch : {'keep', 'discard'}
How to handle the last batch if the remaining length is less than `seq_len`.
    - keep: A batch with fewer samples than previous batches is returned. vocab.padding_token
is used to pad the last batch based on batch size.
- discard: The last batch is discarded if it's smaller than `(seq_len, batch_size)`.
"""
def __init__(self,
vocab,
seq_len,
batch_size,
last_batch='keep'):
self._vocab = vocab
self._seq_len = seq_len
self._batch_size = batch_size
self._last_batch = last_batch
if last_batch not in ['keep', 'discard']:
raise ValueError(
'Got invalid last_batch: "{}". Must be "keep" or "discard".'.
format(last_batch))
if self._last_batch == 'keep':
if not self._vocab.padding_token:
raise ValueError('vocab.padding_token must be specified '
'in vocab when last_batch="keep".')
def __call__(self, corpus):
"""Batchify a dataset.
Parameters
----------
corpus : mxnet.gluon.data.Dataset
A flat dataset to be batchified.
Returns
-------
mxnet.gluon.data.Dataset
        Batches of numericalized samples such that the recurrent states
        from the last batch connect with the current batch for each sample.
Each element of the Dataset is a tuple of size 2, specifying the
data and label for BPTT respectively. Both items are of the same
shape (seq_len, batch_size).
"""
        if self._last_batch == 'keep':
            coded = self._vocab[list(corpus)]
            sample_len = math.ceil(float(len(coded)) / self._batch_size)
            # The padding has two parts (a worked numeric example follows the class definition):
            #   _slice_pad_length(sample_len, self._seq_len + 1, 1) * self._batch_size pads every
            #   column up to a whole number of bptt slices;
            #   sample_len * self._batch_size - len(coded) pads the corpus up to a multiple of
            #   the batch size.
            padding_size = _slice_pad_length(sample_len, self._seq_len + 1, 1) * \
                self._batch_size + sample_len * self._batch_size - len(coded)
            coded.extend([self._vocab[self._vocab.padding_token]] * int(padding_size))
            assert len(coded) % self._batch_size == 0
            assert not _slice_pad_length(len(coded) / self._batch_size, self._seq_len + 1, 1)
else:
sample_len = len(corpus) // self._batch_size
coded = self._vocab[corpus[:sample_len * self._batch_size]]
data = mx.nd.array(coded).reshape((self._batch_size, -1)).T
        batches = slice_sequence(data, self._seq_len + 1, overlap=1)  # seq_len + 1 so each slice yields both data and label of length seq_len, offset by one step
return SimpleDataset(batches).transform(_split_data_label, lazy=False)
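# Worked example of the padding arithmetic in CorpusBPTTBatchify.__call__ (numbers are
# illustrative, not taken from the corpus below): with 97 tokens, batch_size=5 and seq_len=6,
# sample_len = ceil(97 / 5) = 20 tokens per column; _slice_pad_length(20, 7, 1) = 5, so
# padding_size = 5 * 5 + 20 * 5 - 97 = 28. The padded corpus then holds 125 tokens, i.e. 25 per
# column, which slices cleanly into 4 overlapping (data, label) chunks of length seq_len + 1.
assert _slice_pad_length(20, 6 + 1, 1) == 5
assert (25 - (6 + 1)) // 6 + 1 == 4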
def _split_data_label(x):
return x[:-1, :], x[1:, :]
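# Minimal illustration of _split_data_label (added for this write-up): for a batch of shape
# (seq_len + 1, batch_size), data keeps rows 0..seq_len-1 and label keeps rows 1..seq_len,
# so label[t] is the token that follows data[t] -- the usual next-token target for BPTT.
_example = mx.nd.arange(8).reshape((4, 2))  # pretend seq_len=3, batch_size=2
_example_data, _example_label = _split_data_label(_example)
assert _example_data.shape == _example_label.shape == (3, 2)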
if __name__ == '__main__':
sequence = list(np.arange(36).reshape(9,4))
print('sequence: \n', sequence)
    t = slice_sequence(sequence, 5, pad_last=False)
print('pad_last is False, slice sequence \n', t)
t1 = slice_sequence(sequence, 5, pad_last=True, overlap=1)
print('pad_last is True, slice sequence \n', t1)
train_dataset = CorpusDataset('testcorpus.txt', tokenizer=wordtoword_splitter,
flatten=True, eos='<eos>', skip_empty=False)
vocab = nlp.Vocab(nlp.data.Counter(train_dataset), padding_token='<pad>', bos_token=None)
print(train_dataset._data)
print(vocab.idx_to_token)
bptt_batchify = CorpusBPTTBatchify(
vocab, bptt, batch_size, last_batch='keep'
)
train_data = bptt_batchify(train_dataset)
print(train_data._data)
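    # Sanity check (added for this write-up): every batch is a (data, label) pair of shape
    # (bptt, batch_size); with last_batch='keep' only the final batch may have a shorter
    # first dimension.
    first_data, first_label = train_data[0]
    assert first_data.shape == first_label.shape == (bptt, batch_size)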
Output:
sequence:
[array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8., 9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.]), array([20., 21., 22., 23.]), array([24., 25., 26., 27.]), array([28., 29., 30., 31.]), array([32., 33., 34., 35.])]
pad_last is False, slice sequence
[[array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8., 9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.])]]
pad_last is True, slice sequence
[[array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8., 9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.])], [array([16., 17., 18., 19.]), array([20., 21., 22., 23.]), array([24., 25., 26., 27.]), array([28., 29., 30., 31.]), array([32., 33., 34., 35.])]]
['新', '浪', '娱', '乐', '讯', ' ', '北', '京', '时', '间', '9', '月', '4', '日', '消', '息', ',', '据', '《', '名', '利', '场', '》', '报', '道', '称', ',', '罗', '伯', '特', '·', '帕', '丁', '森', '确', '诊', '新', '冠', '阳', '性', ',', '他', '主', '演', '的', '新', '《', '蝙', '蝠', '侠', '》', '电', '影', '拍', '摄', '也', '暂', '停', '。', '<eos>', '<eos>', '不', '久', '前', ',', '《', '每', '日', '邮', '报', '》', '曝', '出', '该', '片', '有', '一', '名', '剧', '组', '人', '员', '感', '染', '新', '冠', ',', '刚', '在', '英', '国', '复', '工', '几', '天', '的', '影', '片', '拍', '摄', '也', '因', '此', '暂', '停', '(', '但', '报', '道', '用', '的', '是', 'c', 'r', 'e', 'w', ',', '而', '非', 'c', 'a', 's', 't', ',', '即', '是', '指', '幕', '后', '工', '作', '人', '员', '而', '非', '演', '员', ')', '。', '两', '小', '时', '后', ',', '华', '纳', '确', '认', '有', '一', '名', '《', '蝙', '蝠', '侠', '》', '制', '作', '团', '队', '成', '员', '感', '染', '了', '新', '冠', ',', '并', '简', '短', '确', '认', '了', '拍', '摄', '暂', '停', '一', '事', ',', '按', '惯', '例', '这', '份', '声', '明', '没', '有', '透', '露', '感', '染', '者', '身', '份', ',', '只', '表', '示', '其', '按', '规', '定', '在', '隔', '离', '中', '。', '<eos>', '<eos>', '而', '又', '是', '两', '小', '时', '后', ',', '《', '名', '利', '场', '》', '称', '另', '有', '高', '层', '消', '息', '源', '称', '是', '帕', '丁', '森', '新', '冠', '检', '测', '阳', '性', '。', '他', '的', '代', '理', '人', '尚', '未', '就', '此', '报', '道', '做', '出', '回', '复', '。', '<eos>', '<eos>']
['<unk>', '<pad>', '<eos>', ',', '新', '。', '《', '》', '冠', '名', '员', '报', '是', '有', '的', '一', '人', '停', '后', '感', '拍', '摄', '时', '暂', '染', '确', '称', '而', '道', 'c', '丁', '两', '也', '了', '他', '份', '作', '侠', '出', '利', '在', '场', '复', '小', '工', '帕', '影', '性', '息', '按', '日', '森', '此', '消', '演', '片', '蝙', '蝠', '认', '阳', '非', ' ', '4', '9', 'a', 'e', 'r', 's', 't', 'w', '·', '不', '中', '主', '久', '乐', '事', '京', '代', '伯', '但', '例', '做', '其', '几', '刚', '制', '前', '剧', '北', '华', '即', '又', '另', '只', '回', '因', '团', '国', '声', '天', '娱', '定', '尚', '就', '层', '幕', '并', '惯', '成', '指', '据', '明', '曝', '月', '未', '检', '每', '没', '测', '浪', '源', '特', '理', '用', '电', '短', '示', '离', '简', '纳', '组', '罗', '者', '英', '表', '规', '讯', '诊', '该', '身', '这', '透', '邮', '间', '队', '隔', '露', '高', '(', ')']
9
(
[[ 57. 23. 97. 92. 1.]
[ 37. 17. 145. 12. 1.]
[ 7. 149. 109. 31. 1.]
[125. 80. 10. 43. 1.]
[ 46. 11. 19. 22. 1.]
[ 20. 28. 24. 18. 1.]]
<NDArray 6x5 @cpu(0)>,
[[ 37. 17. 145. 12. 1.]
[ 7. 149. 109. 31. 1.]
[125. 80. 10. 43. 1.]
[ 46. 11. 19. 22. 1.]
[ 20. 28. 24. 18. 1.]
[ 21. 124. 33. 3. 1.]]
<NDArray 6x5 @cpu(0)>)
testcorpus.txt contents:
新浪娱乐讯 北京时间9月4日消息,据《名利场》报道称,罗伯特·帕丁森确诊新冠阳性,他主演的新《蝙蝠侠》电影拍摄也暂停。 不久前,《每日邮报》曝出该片有一名剧组人员感染新冠,刚在英国复工几天的影片拍摄也因此暂停(但报道用的是crew,而非cast,即是指幕后工作人员而非演员)。两小时后,华纳确认有一名《蝙蝠侠》制作团队成员感染了新冠,并简短确认了拍摄暂停一事,按惯例这份声明没有透露感染者身份,只表示其按规定在隔离中。 而又是两小时后,《名利场》称另有高层消息源称是帕丁森新冠检测阳性。他的代理人尚未就此报道做出回复。