nlp.data.batchify.CorpusBPTTBatchify的学习理解

探讨了在自然语言处理任务中,如何利用截断的反向传播通过时间(BPTT)技术来处理长文本序列,以解决递归神经网络(RNN)中的梯度消失问题。介绍了自定义数据集批处理方法,包括序列切片、填充和批处理策略,以及如何使用Gluon NLP库实现这些技术。
摘要由CSDN通过智能技术生成
import gluonnlp as nlp
from gluonnlp.data.dataset import CorpusDataset,SimpleDataset
from mxnet import np,npx
import mxnet as mx
import math
npx.set_np()
#从下面代码理解,bptt通过设置seq_length来限制文本的训练的长度,从而来解决RNN存在的梯度消逝的问题
batch_size = 5
bptt = 6

def wordtoword_splitter(s):
    """按字分割"""
    return list(s)



def _slice_pad_length(num_items, length, overlap=0):
    """Calculate the padding length needed for sliced samples in order not to discard data.
    计算在不丢弃批次不足的样本时需要补充多少填充项
    Parameters
    ----------
    num_items : int
        Number of items in dataset before collating.
    length : int
        The length of each of the samples.
    overlap : int, default 0
        The extra number of items in current sample that should overlap with the
        next sample.

    Returns
    -------
    Length of paddings.

    """
    if length <= overlap:
        raise ValueError('length needs to be larger than overlap')

    step = length - overlap #overlap部分在数据中只考虑1次,根据公式计算步长把它减掉后面就不考虑它了。
    span = num_items - length #-length的目的是只考虑1次overlap,后面的span能不能给step整除,少则补。
    residual = span % step
    if residual:
        return step - residual
    else:
        return 0
def slice_sequence(sequence, length, pad_last=False, pad_val='<pad>', overlap=0):
    """Slice a flat sequence of tokens into sequences tokens, with each
    inner sequence's length equal to the specified `length`, taking into account the requested
    sequence overlap.

    Parameters
    ----------
    sequence : list of object
        A flat list of tokens.
    length : int
        The length of each of the samples.
    pad_last : bool, default False
        Whether to pad the last sequence when its length doesn't align. If the last sequence's
        length doesn't align and ``pad_last`` is False, it will be dropped.
    pad_val : object, default
        The padding value to use when the padding of the last sequence is enabled. In general,
        the type of ``pad_val`` should be the same as the tokens.
    overlap : int, default 0
        The extra number of items in current sample that should overlap with the
        next sample.

    Returns
    -------
    List of list of tokens, with the length of each inner list equal to `length`.

    """
    if length <= overlap:
        raise ValueError('length needs to be larger than overlap')

    if pad_last:
        pad_len = _slice_pad_length(len(sequence), length, overlap)
        sequence = sequence + [pad_val] * pad_len
    num_samples = (len(sequence) - length) // (length - overlap) + 1
    # slice头部考虑overlap i*(length - overlap)
    # slice尾部不考虑overlap,只考虑序列长度 (i)*(length-overlap)+length
    return [sequence[i * (length - overlap): ((i + 1) * length - i * overlap)]
            for i in range(num_samples)]

class CorpusBPTTBatchify:
    """Transform the dataset into batches of numericalized samples, in the way
    that the recurrent states from last batch connects with the current batch
    for each sample.

    Each sample is of shape `(seq_len, batch_size)`. When `last_batch='keep'`, the first
    dimension of last sample may be shorter than `seq_len`.

    Parameters
    ----------
    vocab : gluonnlp.Vocab
        The vocabulary to use for numericalizing the dataset. Each token will be mapped to the
        index according to the vocabulary.
    seq_len : int
        The length of each of the samples for truncated back-propagation-through-time (TBPTT).
    batch_size : int
        The number of samples in each batch.
    last_batch : {'keep', 'discard'}
        How to handle the last batch if the remaining length is less than `seq_len`.

        - keep: A batch with less samples than previous batches is returned. vocab.padding_token
          is used to pad the last batch based on batch size.

        - discard: The last batch is discarded if it's smaller than `(seq_len, batch_size)`.
    """

    def __init__(self,
                 vocab,
                 seq_len,
                 batch_size,
                 last_batch='keep'):
        self._vocab = vocab
        self._seq_len = seq_len
        self._batch_size = batch_size
        self._last_batch = last_batch

        if last_batch not in ['keep', 'discard']:
            raise ValueError(
                'Got invalid last_batch: "{}". Must be "keep" or "discard".'.
                format(last_batch))

        if self._last_batch == 'keep':
            if not self._vocab.padding_token:
                raise ValueError('vocab.padding_token must be specified '
                                 'in vocab when last_batch="keep".')

    def __call__(self, corpus):
        """Batchify a dataset.

        Parameters
        ----------
        corpus : mxnet.gluon.data.Dataset
            A flat dataset to be batchified.

        Returns
        -------
        mxnet.gluon.data.Dataset
            Batches of numericalized samples such that the recurrent states
            from last batch connects with the current batch for each sample.
            Each element of the Dataset is a tuple of size 2, specifying the
            data and label for BPTT respectively. Both items are of the same
            shape (seq_len, batch_size).
        """
        if self._last_batch == 'keep':
            coded = self._vocab[list(corpus)]
            sample_len = math.ceil(float(len(coded)) / self._batch_size)
            #填充分为2部分,
            # _slice_pad_length(sample_len, self._seq_len + 1, 1) * self._batch_size 计算一个批次bptt需要填充的数量
            # sample_len * self._batch_size - len(coded) 计算批次需要填充的数量
            padding_size = _slice_pad_length(sample_len, self._seq_len + 1, 1) * \
                 + sample_len * self._batch_size - len(coded)
            coded.extend([self._vocab[self._vocab.padding_token]] * int(padding_size))
            assert len(coded) % self._batch_size == 0
            assert not _slice_pad_length(len(coded) / self._batch_size, self._seq_len + 1, 1)
        else:
            sample_len = len(corpus) // self._batch_size
            coded = self._vocab[corpus[:sample_len * self._batch_size]]
        data = mx.nd.array(coded).reshape((self._batch_size, -1)).T
        batches = slice_sequence(data, self._seq_len + 1, overlap=1) #为什么+1 是因为bptt的长度为seq_len。

        return SimpleDataset(batches).transform(_split_data_label, lazy=False)

def _split_data_label(x):
    return x[:-1, :], x[1:, :]

if __name__ == '__main__':
    sequence = list(np.arange(36).reshape(9,4))
    print('sequence: \n', sequence)
    t=slice_sequence(sequence,5, pad_last=False)
    print('pad_last is False, slice sequence \n', t)
    t1 = slice_sequence(sequence, 5, pad_last=True, overlap=1)
    print('pad_last is True, slice sequence \n', t1)

    train_dataset = CorpusDataset('testcorpus.txt', tokenizer=wordtoword_splitter,
                                  flatten=True, eos='<eos>', skip_empty=False)
    vocab = nlp.Vocab(nlp.data.Counter(train_dataset), padding_token='<pad>', bos_token=None)
    print(train_dataset._data)
    print(vocab.idx_to_token)
    bptt_batchify = CorpusBPTTBatchify(
        vocab, bptt, batch_size, last_batch='keep'
    )

    train_data = bptt_batchify(train_dataset)
    print(train_data._data)

结果

sequence: 
 [array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8.,  9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.]), array([20., 21., 22., 23.]), array([24., 25., 26., 27.]), array([28., 29., 30., 31.]), array([32., 33., 34., 35.])]
pad_last is False, slice sequence 
 [[array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8.,  9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.])]]
pad_last is True, slice sequence 
 [[array([0., 1., 2., 3.]), array([4., 5., 6., 7.]), array([ 8.,  9., 10., 11.]), array([12., 13., 14., 15.]), array([16., 17., 18., 19.])], [array([16., 17., 18., 19.]), array([20., 21., 22., 23.]), array([24., 25., 26., 27.]), array([28., 29., 30., 31.]), array([32., 33., 34., 35.])]]
['新', '浪', '娱', '乐', '讯', ' ', '北', '京', '时', '间', '9', '月', '4', '日', '消', '息', ',', '据', '《', '名', '利', '场', '》', '报', '道', '称', ',', '罗', '伯', '特', '·', '帕', '丁', '森', '确', '诊', '新', '冠', '阳', '性', ',', '他', '主', '演', '的', '新', '《', '蝙', '蝠', '侠', '》', '电', '影', '拍', '摄', '也', '暂', '停', '。', '<eos>', '<eos>', '不', '久', '前', ',', '《', '每', '日', '邮', '报', '》', '曝', '出', '该', '片', '有', '一', '名', '剧', '组', '人', '员', '感', '染', '新', '冠', ',', '刚', '在', '英', '国', '复', '工', '几', '天', '的', '影', '片', '拍', '摄', '也', '因', '此', '暂', '停', '(', '但', '报', '道', '用', '的', '是', 'c', 'r', 'e', 'w', ',', '而', '非', 'c', 'a', 's', 't', ',', '即', '是', '指', '幕', '后', '工', '作', '人', '员', '而', '非', '演', '员', ')', '。', '两', '小', '时', '后', ',', '华', '纳', '确', '认', '有', '一', '名', '《', '蝙', '蝠', '侠', '》', '制', '作', '团', '队', '成', '员', '感', '染', '了', '新', '冠', ',', '并', '简', '短', '确', '认', '了', '拍', '摄', '暂', '停', '一', '事', ',', '按', '惯', '例', '这', '份', '声', '明', '没', '有', '透', '露', '感', '染', '者', '身', '份', ',', '只', '表', '示', '其', '按', '规', '定', '在', '隔', '离', '中', '。', '<eos>', '<eos>', '而', '又', '是', '两', '小', '时', '后', ',', '《', '名', '利', '场', '》', '称', '另', '有', '高', '层', '消', '息', '源', '称', '是', '帕', '丁', '森', '新', '冠', '检', '测', '阳', '性', '。', '他', '的', '代', '理', '人', '尚', '未', '就', '此', '报', '道', '做', '出', '回', '复', '。', '<eos>', '<eos>']
['<unk>', '<pad>', '<eos>', ',', '新', '。', '《', '》', '冠', '名', '员', '报', '是', '有', '的', '一', '人', '停', '后', '感', '拍', '摄', '时', '暂', '染', '确', '称', '而', '道', 'c', '丁', '两', '也', '了', '他', '份', '作', '侠', '出', '利', '在', '场', '复', '小', '工', '帕', '影', '性', '息', '按', '日', '森', '此', '消', '演', '片', '蝙', '蝠', '认', '阳', '非', ' ', '4', '9', 'a', 'e', 'r', 's', 't', 'w', '·', '不', '中', '主', '久', '乐', '事', '京', '代', '伯', '但', '例', '做', '其', '几', '刚', '制', '前', '剧', '北', '华', '即', '又', '另', '只', '回', '因', '团', '国', '声', '天', '娱', '定', '尚', '就', '层', '幕', '并', '惯', '成', '指', '据', '明', '曝', '月', '未', '检', '每', '没', '测', '浪', '源', '特', '理', '用', '电', '短', '示', '离', '简', '纳', '组', '罗', '者', '英', '表', '规', '讯', '诊', '该', '身', '这', '透', '邮', '间', '队', '隔', '露', '高', '(', ')']
9
(
[[ 57.  23.  97.  92.   1.]
 [ 37.  17. 145.  12.   1.]
 [  7. 149. 109.  31.   1.]
 [125.  80.  10.  43.   1.]
 [ 46.  11.  19.  22.   1.]
 [ 20.  28.  24.  18.   1.]]
<NDArray 6x5 @cpu(0)>, 
[[ 37.  17. 145.  12.   1.]
 [  7. 149. 109.  31.   1.]
 [125.  80.  10.  43.   1.]
 [ 46.  11.  19.  22.   1.]
 [ 20.  28.  24.  18.   1.]
 [ 21. 124.  33.   3.   1.]]
<NDArray 6x5 @cpu(0)>)

 

text_corpus.txt 内容:

新浪娱乐讯 北京时间9月4日消息,据《名利场》报道称,罗伯特·帕丁森确诊新冠阳性,他主演的新《蝙蝠侠》电影拍摄也暂停。

  不久前,《每日邮报》曝出该片有一名剧组人员感染新冠,刚在英国复工几天的影片拍摄也因此暂停(但报道用的是crew,而非cast,即是指幕后工作人员而非演员)。两小时后,华纳确认有一名《蝙蝠侠》制作团队成员感染了新冠,并简短确认了拍摄暂停一事,按惯例这份声明没有透露感染者身份,只表示其按规定在隔离中。

  而又是两小时后,《名利场》称另有高层消息源称是帕丁森新冠检测阳性。他的代理人尚未就此报道做出回复。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值