FixedBucketSampler (Part 2)

import numpy as np
import gluonnlp
from gluonnlp.data.sampler import ConstWidthBucket, _match_bucket_keys, _bucket_stats
import warnings

class Sampler(object):
    """Abstract base class for samplers.

    A sampler produces a stream of sample indices (or batches of indices).
    Concrete subclasses must provide both ``__iter__`` and ``__len__``.
    """

    def __len__(self):
        # Number of items the sampler yields per epoch; subclass responsibility.
        raise NotImplementedError

    def __iter__(self):
        # Iterate over sampled indices/batches; subclass responsibility.
        raise NotImplementedError


class FixedBucketSampler(Sampler):
    r"""Assign each data sample to a fixed bucket based on its length.
    The bucket keys are either given or generated from the input sequence lengths.

    Parameters
    ----------
    lengths : list of int or list of tuple/list of int
        The length of the sequences in the input data sample. Each element may be a
        single int or a tuple of ints (e.g. (src_len, tgt_len) for paired sequences).
    batch_size : int
        The batch size of the sampler.
    num_buckets : int or None, default 10
        The number of buckets. This will not be used if bucket_keys is set.
    bucket_keys : None or list of int or list of tuple, default None
        The keys (upper boundaries) that will be used to create the buckets. It should
        usually be the lengths of the sequences. If it is None, the bucket_keys will be
        generated based on the maximum lengths of the data.
    ratio : float, default 0
        Ratio to scale up the batch size of smaller buckets, so buckets of short
        sequences can hold more samples per batch.
        Assume the :math:`i` th key is :math:`K_i` ,
        the default batch size is :math:`B` , the ratio to scale the batch size is
        :math:`\alpha` and
        the batch size corresponds to the :math:`i` th bucket is :math:`B_i` . We have:

        .. math::

            B_i = \max(\alpha B \times \frac{\max_j sum(K_j)}{sum(K_i)}, B)

        Thus, setting this to a value larger than 0, like 0.5, will scale up the batch size of the
        smaller buckets.
    shuffle : bool, default False
        Whether to shuffle the batches.
    use_average_length : bool, default False
        False: each batch contains batch_size sequences, number of sequence elements varies.
        True: each batch contains batch_size elements, number of sequences varies. In this case,
        ratio option is ignored.
    num_shards : int, default 0
        If num_shards > 0, the sampled batch is split into num_shards smaller batches.
        The output will have structure of list(list(int)).
        If num_shards = 0, the output will have structure of list(int).
        This is useful in multi-gpu training and can potentially reduce the number of paddings.
        In general, it is set to the number of gpus.
    bucket_scheme : BucketScheme, default ConstWidthBucket
        It is used to generate bucket keys. It supports:
        ConstWidthBucket: all the buckets have the same width
        LinearWidthBucket: the width of ith  bucket follows :math:`w_i = \alpha * i + 1`
        ExpWidthBucket: the width of ith bucket follows
        :math:`w_i` = bucket_len_step :math:`* w_{i-1}`
    Examples
    --------
    >>> lengths = [np.random.randint(1, 100) for _ in range(1000)]
    >>> sampler = gluonnlp.data.FixedBucketSampler(lengths, 8, ratio=0.5)
    >>> print(sampler.stats())
    FixedBucketSampler:
    -etc-
    """
    def __init__(self, lengths, batch_size, num_buckets=10, bucket_keys=None,
                 ratio=0, shuffle=False, use_average_length=False, num_shards=0,
                 bucket_scheme=ConstWidthBucket()):
        assert len(lengths) > 0, 'FixedBucketSampler does not support empty lengths.'
        assert batch_size > 0, 'Batch size must be larger than 0.'
        assert ratio >= 0, 'batch size scaling ratio cannot be negative.'
        self._batch_size = batch_size
        self._ratio = ratio
        self._lengths = np.array(lengths, dtype=np.int32)
        if self._lengths.ndim == 1:
            # Single length per sample (e.g. language modeling).
            self._single_element = True
            attr_num = 1
        else:
            # One length tuple per sample (e.g. (src_len, tgt_len) for translation).
            assert self._lengths.ndim == 2, \
                'Elements in lengths must be either int or tuple/list of int. ' \
                'Received lengths=%s' % str(lengths)
            self._single_element = False
            attr_num = self._lengths.shape[1]
        self._shuffle = shuffle
        self._num_shards = num_shards
        self._bucket_scheme = bucket_scheme
        max_lengths = self._lengths.max(axis=0)
        min_lengths = self._lengths.min(axis=0)
        if self._single_element:
            assert min_lengths > 0, 'Sequence lengths must all be larger than 0.'
        else:
            for _, ele in enumerate(min_lengths):
                assert ele > 0, 'Sequence lengths must all be larger than 0.'
        # Generate the buckets
        if bucket_keys is None:
            assert num_buckets > 0, 'num_buckets must be set when bucket_keys is None. Received ' \
                                    'num_buckets=%d' % num_buckets
            bucket_keys = bucket_scheme(max_lengths, min_lengths, num_buckets)
        else:
            if num_buckets is not None:
                warnings.warn('num_buckets will not be used if bucket_keys is not None. '
                              'bucket_keys=%s, num_buckets=%d' % (str(bucket_keys), num_buckets))
            assert len(bucket_keys) > 0
            if self._single_element:
                assert isinstance(bucket_keys[0], int)
            else:
                assert isinstance(bucket_keys[0], tuple)
                assert len(bucket_keys[0]) == attr_num
        # Deduplicate and sort so that bucket keys are strictly increasing boundaries.
        bucket_keys = sorted(set(bucket_keys))
        # Assign instances to buckets
        bucket_sample_ids = _match_bucket_keys(bucket_keys, self._lengths)
        unused_bucket_keys = [key for key, sample_ids in zip(bucket_keys, bucket_sample_ids)
                              if len(sample_ids) == 0]
        if len(unused_bucket_keys) > 0:
            warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %
                          str(unused_bucket_keys))
        # Remove empty buckets
        self._bucket_keys = [key for key, sample_ids in zip(bucket_keys, bucket_sample_ids)
                             if len(sample_ids) > 0]

        self._bucket_sample_ids = [sample_ids for sample_ids in bucket_sample_ids
                                   if len(sample_ids) > 0]
        if not use_average_length:
            # Scale up the batch size of short-sequence buckets by `ratio`
            # (see the formula in the class docstring); never below batch_size.
            scale_up_keys = [key if self._single_element else sum(key) for key
                             in self._bucket_keys]
            max_scale_up_key = max(scale_up_keys)
            self._bucket_batch_sizes = [max(int(max_scale_up_key / float(scale_up_key)
                                                * self._ratio * batch_size), batch_size)
                                        for scale_up_key in scale_up_keys]
        else:
            # batch_size counts sequence *elements*: pick per-bucket sizes so each
            # batch has roughly batch_size tokens (mean + std as a safety margin).
            if ratio > 0.:
                warnings.warn('ratio=%f is ignored when use_average_length is True' % self._ratio)
            bucket_average_lengths, bucket_length_stds = _bucket_stats(self._bucket_sample_ids,
                                                                       self._lengths)
            self._bucket_batch_sizes = [max(int(batch_size / (average_length + length_std)), 1)
                                        for average_length, length_std
                                        in zip(bucket_average_lengths, bucket_length_stds)]
        # Precompute (bucket_id, batch_start_offset) pairs, largest buckets first.
        self._batch_infos = []
        for bucket_id, sample_ids, bucket_batch_size in\
                zip(range(len(self._bucket_keys) - 1, -1, -1),
                        self._bucket_sample_ids[::-1],
                        self._bucket_batch_sizes[::-1]):
            for i in range(0, len(sample_ids), bucket_batch_size):
                self._batch_infos.append((bucket_id, i))

        if self._num_shards > 0:
            # Ceiling division (fixes a NameError: `math.ceil` was used without
            # importing `math`); each __iter__ step emits num_shards batches.
            self._sampler_size = -(-len(self._batch_infos) // self._num_shards)
        else:
            self._sampler_size = len(self._batch_infos)

    def __iter__(self):
        """Yield batches of sample indices.

        Yields list(int) when num_shards == 0, else list(list(int)) with one
        sub-list per shard.
        """
        if self._shuffle:
            # Shuffle both the batch order and the sample order inside each bucket.
            np.random.shuffle(self._batch_infos)
            for bucket_id in range(len(self._bucket_keys)):
                np.random.shuffle(self._bucket_sample_ids[bucket_id])

        if self._num_shards > 0:
            for batch_idx in range(0, len(self._batch_infos), self._num_shards):
                if batch_idx + self._num_shards > len(self._batch_infos):
                    # Not enough batches left for a full shard group: back up so the
                    # last group is full (some batches are intentionally re-yielded).
                    batch_idx = len(self._batch_infos) - self._num_shards
                batch = self._batch_infos[batch_idx: batch_idx + self._num_shards]
                bucket_ids, batch_begins = list(zip(*batch))
                batch_sizes = [self._bucket_batch_sizes[bucket_id] for bucket_id in bucket_ids]
                batch_ends = [min(batch_begin + batch_size,
                                  len(self._bucket_sample_ids[bucket_id]))
                              for bucket_id, batch_begin, batch_size in zip(bucket_ids,
                                                                            batch_begins,
                                                                            batch_sizes)]
                yield [self._bucket_sample_ids[bucket_id][batch_begin:batch_end]
                       for bucket_id, batch_begin, batch_end in zip(bucket_ids,
                                                                    batch_begins,
                                                                    batch_ends)]
        else:
            for bucket_id, batch_begin in self._batch_infos:
                batch_size = self._bucket_batch_sizes[bucket_id]
                batch_end = min(batch_begin + batch_size, len(self._bucket_sample_ids[bucket_id]))
                yield self._bucket_sample_ids[bucket_id][batch_begin:batch_end]

    def __len__(self):
        """Return the number of (shard groups of) batches per epoch."""
        return self._sampler_size

    def stats(self):
        """Return a string representing the statistics of the bucketing sampler.

        Returns
        -------
        ret : str
            String representing the statistics of the buckets.
        """
        ret = '{name}:\n' \
            '  sample_num={sample_num}, batch_num={batch_num}\n' \
            '  key={bucket_keys}\n' \
            '  cnt={bucket_counts}\n' \
            '  batch_size={bucket_batch_sizes}'\
            .format(name=self.__class__.__name__,
                    sample_num=len(self._lengths),
                    batch_num=len(self._batch_infos),
                    bucket_keys=self._bucket_keys,
                    bucket_counts=[len(sample_ids) for sample_ids in self._bucket_sample_ids],
                    bucket_batch_sizes=self._bucket_batch_sizes)
        return ret


# Demo: bucket 1000 random sequence lengths with batch_size=8, ratio=0.5,
# and split every sampled batch across 4 shards (as if for 4 GPUs).
lengths = [np.random.randint(1, 100) for _ in range(1000)]
sampler = gluonnlp.data.FixedBucketSampler(lengths, 8, ratio=0.5, num_shards=4)
print('lengths: ', lengths)
print('样本的统计信息: ', sampler.stats())
print('总样本batch数: ', len(sampler))
for shard_batch in sampler:
    # Only show the first shard group of batches.
    print('单个样本: ', shard_batch)
    break

结果:

lengths:  [54, 12, 12, 5, 80, 8, 33, 3, 22, 58, 86, 58, 66, 13, 54, 19, 74, 8, 56, 55, 93, 72, 21, 56, 93, 76, 11, 16, 24, 85, 30, 5, 47, 40, 7, 74, 5, 92, 7, 51, 90, 32, 97, 6, 64, 28, 98, 68, 62, 21, 89, 70, 27, 19, 50, 48, 3, 76, 71, 62, 77, 33, 3, 29, 75, 52, 54, 73, 35, 2, 74, 13, 73, 58, 87, 52, 52, 6, 32, 52, 67, 65, 21, 93, 8, 88, 98, 15, 84, 59, 34, 58, 37, 12, 25, 98, 3, 1, 39, 66, 55, 1, 56, 17, 96, 56, 58, 71, 61, 63, 69, 46, 43, 19, 26, 69, 33, 17, 95, 85, 68, 65, 45, 32, 46, 81, 88, 83, 57, 87, 45, 35, 61, 50, 85, 45, 58, 71, 72, 69, 6, 31, 19, 57, 14, 48, 58, 82, 62, 25, 5, 67, 34, 42, 38, 89, 86, 96, 35, 8, 25, 21, 17, 3, 24, 5, 24, 21, 7, 96, 4, 25, 94, 15, 20, 85, 31, 10, 60, 91, 41, 14, 67, 58, 13, 80, 81, 80, 64, 76, 86, 49, 68, 64, 92, 14, 91, 9, 98, 81, 67, 98, 33, 13, 98, 59, 32, 19, 83, 36, 7, 88, 96, 47, 63, 16, 56, 56, 80, 5, 35, 7, 16, 51, 60, 10, 67, 19, 2, 87, 40, 46, 14, 24, 90, 35, 51, 67, 40, 88, 25, 60, 72, 88, 14, 86, 98, 97, 42, 71, 34, 54, 31, 18, 12, 42, 16, 68, 26, 41, 60, 3, 76, 18, 3, 30, 46, 81, 64, 44, 12, 37, 88, 47, 59, 85, 77, 89, 39, 71, 55, 69, 51, 51, 84, 80, 78, 94, 66, 12, 64, 78, 54, 1, 82, 43, 93, 6, 93, 64, 15, 34, 47, 38, 53, 68, 83, 51, 43, 16, 10, 62, 63, 61, 77, 76, 28, 94, 96, 41, 75, 65, 45, 38, 45, 77, 39, 23, 51, 17, 19, 68, 2, 29, 82, 7, 93, 9, 52, 16, 72, 81, 85, 14, 59, 56, 21, 49, 7, 26, 86, 39, 20, 2, 20, 34, 3, 73, 84, 44, 65, 24, 38, 44, 10, 59, 79, 11, 32, 73, 22, 1, 32, 81, 29, 3, 26, 26, 55, 75, 64, 61, 85, 51, 16, 46, 99, 81, 21, 77, 62, 49, 10, 23, 85, 94, 54, 81, 88, 97, 33, 26, 17, 90, 43, 45, 21, 77, 51, 97, 33, 4, 48, 41, 94, 94, 63, 76, 78, 73, 98, 60, 34, 82, 1, 79, 33, 45, 69, 77, 83, 55, 24, 95, 94, 9, 97, 33, 27, 24, 59, 55, 78, 55, 99, 63, 77, 86, 85, 47, 3, 43, 32, 70, 15, 59, 15, 56, 84, 70, 15, 40, 64, 27, 93, 82, 90, 85, 81, 62, 40, 10, 40, 31, 98, 70, 57, 58, 30, 85, 38, 72, 96, 11, 85, 9, 33, 61, 73, 52, 95, 38, 12, 44, 18, 18, 1, 3, 38, 66, 40, 10, 83, 56, 57, 22, 8, 61, 36, 
90, 33, 32, 13, 95, 26, 22, 60, 70, 48, 8, 76, 85, 95, 1, 20, 33, 86, 29, 13, 8, 97, 50, 52, 89, 43, 97, 48, 34, 86, 89, 72, 59, 72, 67, 17, 88, 19, 1, 6, 73, 72, 92, 59, 1, 74, 12, 21, 88, 8, 77, 82, 20, 97, 81, 57, 78, 30, 60, 19, 2, 49, 97, 38, 75, 40, 43, 63, 94, 6, 16, 83, 79, 81, 83, 12, 52, 11, 73, 86, 86, 55, 35, 64, 9, 55, 76, 14, 39, 56, 84, 3, 2, 32, 57, 44, 57, 25, 26, 29, 54, 60, 69, 19, 44, 11, 74, 88, 49, 30, 61, 10, 68, 57, 40, 97, 14, 73, 17, 88, 77, 89, 98, 69, 97, 70, 61, 22, 16, 2, 93, 89, 12, 12, 57, 25, 86, 21, 70, 14, 1, 37, 63, 67, 22, 32, 42, 54, 40, 12, 1, 7, 7, 60, 55, 34, 68, 83, 92, 49, 2, 34, 45, 81, 27, 40, 61, 10, 61, 65, 54, 30, 11, 8, 18, 96, 52, 3, 2, 15, 30, 76, 2, 66, 39, 30, 83, 39, 81, 5, 61, 22, 9, 55, 29, 68, 89, 28, 38, 21, 3, 44, 68, 9, 51, 29, 8, 65, 70, 97, 12, 36, 21, 7, 71, 12, 62, 40, 1, 81, 73, 38, 77, 88, 28, 34, 77, 60, 50, 70, 97, 53, 6, 93, 88, 51, 1, 60, 69, 73, 97, 93, 7, 32, 53, 34, 64, 17, 55, 64, 11, 69, 25, 55, 10, 91, 48, 94, 21, 51, 79, 53, 4, 2, 99, 60, 15, 75, 39, 49, 87, 85, 95, 80, 47, 79, 45, 80, 35, 72, 18, 5, 76, 63, 58, 39, 31, 16, 94, 6, 28, 25, 20, 28, 31, 85, 88, 64, 86, 37, 56, 56, 26, 18, 17, 22, 53, 90, 43, 35, 72, 39, 91, 53, 94, 33, 41, 50, 54, 32, 12, 18, 9, 80, 78, 75, 84, 24, 30, 18, 32, 25, 63, 81, 1, 1, 45, 54, 27, 82, 97, 32, 31, 13, 46, 78, 89, 55, 10, 44, 26, 8, 52, 79, 74, 40, 83, 5, 74, 42, 61, 86, 73, 59, 65, 74, 37, 86, 17, 34, 12, 88, 93, 91, 94, 45, 71, 42, 31, 89, 93, 20, 68, 56, 58, 3, 86, 54, 40, 80, 62, 85, 99, 36, 44, 9, 51, 70, 74, 72, 44, 78, 37, 88, 64, 60, 62, 51, 97, 27, 43, 39, 80, 77, 55, 50, 81, 42, 50, 18, 82, 11, 63, 31, 28, 9, 83, 53, 61, 52, 91, 71, 69, 42, 55, 70, 2, 10, 72, 61, 35, 75, 11, 19, 86, 84, 70, 97, 96, 21, 89, 5, 88, 18, 97, 38, 84, 12, 95, 46, 92, 21, 98, 22, 75, 87, 58, 80, 50, 89, 68, 24, 52, 67, 53, 84, 32, 38, 47, 22, 39, 9, 10, 21, 71, 26, 58, 38, 32, 73, 92]
样本的统计信息:  FixedBucketSampler:
  sample_num=1000, batch_num=107
  key=[18, 27, 36, 45, 54, 63, 72, 81, 90, 99]
  cnt=[190, 82, 85, 88, 81, 107, 89, 94, 100, 84]
  batch_size=[22, 14, 11, 8, 8, 8, 8, 8, 8, 8]
总样本batch数:  27
单个样本:  [[20, 24, 37, 42, 46, 83, 86, 95], [104, 118, 157, 169, 172, 179, 194, 196], [198, 201, 204, 212, 246, 247, 287, 296], [298, 317, 318, 336, 386, 395, 399, 409]]

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值