# Usage of glue.truncate_seqs_equal and glue.concat_sequences

from gluonnlp.data.bert import glue


# Truncate a pair of 3-token sequences so their combined length is at most 4, then 5.
for budget in (4, 5):
    seqs = [[1, 2, 3], [4, 5, 6]]
    truncated = glue.truncate_seqs_equal(seqs, budget)
    print(truncated)

# Join a question/answer pair with SEP/SEP/CLS markers; the call returns
# (tokens, segment_ids, p_mask).
seqs = [['is', 'this', 'jacksonville', '?'], ['no', 'it', 'is', 'not', '.']]
separators = [['SEP'], ['SEP'], ['CLS']]
joined = glue.concat_sequences(seqs, separators, seq_mask=0, separator_mask=1)
print(joined)

结果:
[[1, 2], [4, 5]]
[[1, 2, 3], [4, 5]]
(['is', 'this', 'jacksonville', '?', 'SEP', 'no', 'it', 'is', 'not', '.', 'SEP', 'CLS'], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1])


# Source code of gluonnlp.data.bert.glue (gluonnlp.data.bert.glue 源代码)
__all__ = ['truncate_seqs_equal', 'concat_sequences']

import collections
import itertools
import numpy.ma as ma


def truncate_seqs_equal(sequences, max_len):
    """Truncate a list of seqs equally so that the total length equals max length.

    Sequences whose length does not exceed the current per-sequence share are
    kept whole, and their unused share is redistributed to the others. All
    remaining (longer) sequences are cut to the same length ``quotient``; the
    leftover ``remainder`` tokens are handed out one each to the earliest of
    those sequences in their original order.

    Parameters
    ----------
    sequences : list of list of object
        Sequences of tokens; each must support ``len`` and slicing.
    max_len : int
        Max total length to be truncated to. Must be non-negative.

    Returns
    -------
    list
        ``sequences`` itself (not a copy) when its total length already fits
        in ``max_len``; otherwise a new list of truncated sequences keeping
        the original order.

    Raises
    ------
    ValueError
        If ``max_len`` is negative.

    Examples
    --------
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 6)
    [[1, 2, 3], [4, 5, 6]]
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 4)
    [[1, 2], [4, 5]]
    >>> seqs = [[1, 2, 3], [4, 5, 6]]
    >>> truncate_seqs_equal(seqs, 3)
    [[1, 2], [4]]
    """
    assert isinstance(sequences, list)
    if max_len < 0:
        # A negative budget would yield negative slice bounds (seq[:-k]),
        # which silently drop tokens from the end instead of truncating.
        raise ValueError('max_len must be non-negative, got {}'.format(max_len))
    lens = [len(seq) for seq in sequences]
    if sum(lens) <= max_len:
        return sequences

    budget = max_len
    # Indices (in original order) of sequences still competing for the budget.
    pending = list(range(len(lens)))
    # Visit from shortest to longest; the stable sort keeps ties in index order.
    for idx in sorted(pending, key=lambda i: lens[i]):
        quotient, remainder = divmod(budget, len(pending))
        if lens[idx] <= quotient:
            # Fits within its share: keep it whole and release its tokens.
            pending.remove(idx)
            budget -= lens[idx]
            continue
        # Every pending sequence is longer than the share: cut them all to
        # `quotient`, giving the first `remainder` of them one extra token.
        # The loop always ends here: the pending lengths always sum to more
        # than the budget, so the last pending sequence cannot fit its share.
        for rank, i in enumerate(pending):
            lens[i] = quotient + 1 if rank < remainder else quotient
        break
    return [seq[:length] for seq, length in zip(sequences, lens)]


def concat_sequences(seqs, separators, seq_mask=0, separator_mask=1):
    """Concatenate sequences in a list into a single sequence, using specified separators.

    Sequence i is followed by separator group i. Whichever of ``seqs`` and
    ``separators`` is shorter is padded with empty groups, so a trailing group
    (e.g. ``[CLS]``) may follow the last separator. Besides the tokens, this
    builds a segment id and a mask value for every output token, matching the
    layout of BERT-style inputs.

    Example 1:
    seqs: [['is', 'this' ,'jacksonville', '?'], ['no' ,'it' ,'is' ,'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: 0
    separator_mask: 1

    Returns:
    tokens:      is this jacksonville ? [SEP] no it is not . [SEP] [CLS]
    segment_ids: 0  0    0            0  0    1  1  1  1   1 1     2
    p_mask:      0  0    0            0  1    0  0  0  0   0 1     1

    Example 2:
    separator_mask can also be a list.
    seqs: [['is', 'this' ,'jacksonville', '?'], ['no' ,'it' ,'is' ,'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: 0
    separator_mask: [[1], [1], [0]]

    Returns:
    tokens:     'is this jacksonville ? [SEP] no it is not . [SEP] [CLS]'
    segment_ids: 0  0    0            0  0    1  1  1  1   1 1     2
    p_mask:      0  0    0            0  1    0  0  0  0   0 1     0

    Example 3:
    seq_mask can also be a list.
    seqs: [['is', 'this' ,'jacksonville', '?'], ['no' ,'it' ,'is' ,'not', '.']]
    separators: [[SEP], [SEP], [CLS]]
    seq_mask: [[1, 1, 1, 1], [0, 0, 0, 0, 0]]
    separator_mask: [[1], [1], [0]]

    Returns:
    tokens:     'is this jacksonville ? [SEP] no it is not . [SEP] [CLS]'
    segment_ids: 0  0    0            0  0    1  1  1  1   1 1     2
    p_mask:      1  1    1            1  1    0  0  0  0   0 1     0

    Parameters
    ----------
    seqs : list of list of object
        Sequences to be concatenated. Must be non-empty.
    separators : list of list of object
        The special tokens to separate sequences; group i is appended after
        sequence i.
    seq_mask : int or list of list of int
        A single mask value for all sequence items or a list of values for
        each item in sequences. List lengths are not checked against the
        sequences; a mismatch misaligns ``p_mask`` with the tokens.
    separator_mask : int or list of list of int
        A single mask value for all separators or a list of values for each
        separator (same caveat about lengths as ``seq_mask``).

    Returns
    -------
    tuple of (list, list, list)
        concat : the concatenated tokens.
        segment_ids : for each token, the index i of the (sequence i,
            separator group i) pair it came from.
        p_mask : for each token, its mask value from ``seq_mask`` or
            ``separator_mask``.
    """
    assert isinstance(seqs, collections.abc.Iterable) and len(seqs) > 0
    assert isinstance(seq_mask, (list, int))
    assert isinstance(separator_mask, (list, int))
    if isinstance(seq_mask, int):
        seq_mask = [[seq_mask] * len(seq) for seq in seqs]
    if isinstance(separator_mask, int):
        separator_mask = [[separator_mask] * len(sep) for sep in separators]

    # Pair sequence i with separator group i, padding the shorter side with [].
    pairs = list(itertools.zip_longest(seqs, separators, fillvalue=[]))
    # chain.from_iterable flattens in linear time (sum(lists, []) is quadratic)
    # and does not require seq and sep to be lists of the same type.
    concat = list(itertools.chain.from_iterable(
        itertools.chain(seq, sep) for seq, sep in pairs))
    segment_ids = list(itertools.chain.from_iterable(
        [i] * (len(seq) + len(sep)) for i, (seq, sep) in enumerate(pairs)))
    p_mask = list(itertools.chain.from_iterable(
        itertools.chain(s_mask, sep_mask)
        for s_mask, sep_mask in itertools.zip_longest(seq_mask, separator_mask, fillvalue=[])))
    return concat, segment_ids, p_mask
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值