在进行机器学习时,经常需要打乱样本,这种时候Python中叒有第三方库提供了这个功能——sklearn.utils.shuffle。
Shuffle arrays or sparse matrices in a consistent way. This is a convenience alias to
resample(*arrays, replace=False)
to do random permutations of the collections.
函数参数
Parameters
参数 | 介绍 |
---|---|
*array | 带索引的序列,可以是arrays, lists, dataframes或scipy sparse matrices |
random_state | int,随机量,就是一个random seed。如果是int,该参数作为random seed的值;如果是None,随机生成器就是一个np.random实例 |
n_sample | int,默认为None,输出的样本数目。如果是空,则样本数目会设置为array的第一维元素数 |
Returns
参数 | 介绍 |
---|---|
shuffled_arrays | 带索引的序列,是一个view(也就是说不会改变输入array) |
Examples
解释:例程中建立了3个带索引的序列:array, array和sparse matrix。然后将它们作为一个元组进行shuffle,其中random_state=0表示它们的打乱方式是方式0。这个打乱方式不理解的可以看一下np.random.seed的介绍或者是看我接下来对源码的解析。
源码
shuffle
def shuffle(*arrays, **options):
options['replace'] = False
return resample(*arrays, **options)
Are you kidding? 这是个“空壳函数”。唯一的作用就是将一个参数
replace
置为了False,好让shuffle过程中不影响输入array(不过要记住这个replace
,这是sklearn.utils.shuffle
和sklearn.utils.resample
唯一的区别)。
那么下面来看resample函数。
resample
def resample(*arrays, **options):
'''先是类型检测部分,可以跳过'''
random_state = check_random_state(options.pop('random_state', None)) # 此处注意:返回类型变了,变成:np.random.mtrand._rand或np.random.RandomState(seed)或seed
replace = option.pop('replace', True) # 如果没有‘replace’则返回True
max_n_samples = options.pop('n_samples', None)
if options:
raise ValueError("Unexpected kw arguments: %r" % options.keys())
if len(arrays) == 0:
return None
first = arrays[0]
n_samples = first.shape[0] if hasattr(first, 'shape') else len(first)
if max_n_samples is None:
max_n_samples = n_samples
elif (max_n_samples > n_samples) and (not replace):
raise ValueError("Cannot sample %d out of arrays with dim %d when replace is False" % (max_n_samples, n_samples))
check_consistent_length(*array)
'''开始正文'''
'''重排索引'''
if replace:
indices = random_state.randint(0, n_samples, size=(max_n_samples,)) # 创建新的随机序列索引indices
else:
indices = np.arange(n_samples)
random_state.shuffle(indices)
indices = indices[:max_n_samples]
# convert sparse matrices to CSR for row-based indexing
arrays = [a.tocsr() if issparse(a) else a for a in arrays]
'''根据indices对arrays进行采样'''
resampled_arrays = [safe_indexing(a, indices) for a in arrays]
'''分两种情况,一种是输入的*arrays参数只有一个序列,另一种是输入的*arrays参数是一元组的序列'''
if len(resampled_arrays) == 1:
# syntactic sugar for the unit argument case
return resampled_arrays[0]
else:
return resampled_arrays
解释:代码分为三部分:
- 类型检测
- 构建重排后的索引
- 根据索引输出序列
random_state在函数中用于产生随机索引