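How it works
takeSample() works on the same principle as sample(), but instead of sampling by a relative fraction it samples a specified number of elements. Its result is also no longer an RDD: it is equivalent to calling collect() on the sampled data, so the result comes back as a local array on a single machine (the driver).
In the figure, the boxes on the left represent the partitions on the distributed nodes, and the box on the right represents the result array returned to the single machine. Here takeSample is asked for a sample of one element, and the returned result is V1.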
Source code
/**
 * Return a fixed-size sampled subset of this RDD in an array
 *
 * @param withReplacement whether sampling is done with replacement
 * @param num size of the returned sample
 * @param seed seed for the random number generator
 * @return sample of specified size in an array
 */
def takeSample(withReplacement: Boolean,
    num: Int,
    seed: Long = Utils.random.nextLong): Array[T] = {
  val numStDev = 10.0
  if (num < 0) {
    throw new IllegalArgumentException("Negative number of elements requested")
  } else if (num == 0) {
    return new Array[T](0)
  }
  val initialCount = this.count()
  if (initialCount == 0) {
    return new Array[T](0)
  }
  val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
  if (num > maxSampleSize) {
    throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
      s"$numStDev * math.sqrt(Int.MaxValue)")
  }
  val rand = new Random(seed)
  if (!withReplacement && num >= initialCount) {
    return Utils.randomizeInPlace(this.collect(), rand)
  }
  val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
    withReplacement)
  var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
  // If the first sample didn't turn out large enough, keep trying to take samples;
  // this shouldn't happen often because we use a big multiplier for the initial size
  var numIters = 0
  while (samples.length < num) {
    logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
    samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
    numIters += 1
  }
  Utils.randomizeInPlace(samples, rand).take(num)
}
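The control flow of the source above (oversample by a fraction, retry if the sample is too small, shuffle, then truncate to num) can be tried without a cluster. The following is a minimal sketch in plain Scala, not Spark's implementation: `TakeSampleSketch`, `takeSampleLocal`, and the `oversample` factor are hypothetical names introduced here for illustration, it covers only the without-replacement case, and the fixed factor of 2.0 is an assumption standing in for the fraction that Spark gets from SamplingUtils.computeFractionForSampleSize.

```scala
import scala.util.Random

object TakeSampleSketch {
  // Hypothetical oversampling factor (an assumption for this sketch).
  // Spark computes its fraction in SamplingUtils.computeFractionForSampleSize.
  val oversample = 2.0

  // Local, without-replacement analogue of the takeSample control flow.
  def takeSampleLocal[T](data: Vector[T], num: Int, seed: Long): Vector[T] = {
    require(num >= 0, "Negative number of elements requested")
    val rand = new Random(seed)
    if (num == 0 || data.isEmpty) {
      Vector.empty
    } else if (num >= data.length) {
      // Asking for everything (or more): return all elements in random order.
      rand.shuffle(data)
    } else {
      // Sample with a fraction larger than num / size so that one pass
      // usually yields at least num elements.
      val fraction = math.min(1.0, oversample * num / data.length)
      def bernoulliPass(): Vector[T] = {
        val passRand = new Random(rand.nextInt())
        data.filter(_ => passRand.nextDouble() < fraction)
      }
      var samples = bernoulliPass()
      // Retry with a fresh seed while the sample is too small.
      while (samples.length < num) {
        samples = bernoulliPass()
      }
      // Shuffle, then keep exactly num elements.
      rand.shuffle(samples).take(num)
    }
  }
}
```

For instance, `TakeSampleSketch.takeSampleLocal((1 to 100).toVector, 10, 9L)` returns 10 distinct elements, and calling it again with the same seed returns the same elements in the same order.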
Hands-on usage
scala> val rdd = sc.makeRDD(1 to 100,3)
rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[9] at makeRDD at <console>:27
scala> rdd.takeSample(true,10,9)
res10: Array[Int] = Array(56, 62, 52, 45, 93, 78, 71, 9, 60, 23)
scala> rdd.takeSample(true,10,10)
res12: Array[Int] = Array(70, 11, 20, 11, 28, 51, 57, 12, 100, 40)
scala> rdd.takeSample(true,10,11)
res13: Array[Int] = Array(18, 5, 44, 10, 51, 75, 8, 54, 79, 16)