Source code analysis of the parallelize method
- parallelize creates an RDD from a local collection:
val rdd1 = sc.parallelize(List(1,2,3,4,5,6))
- Let's look at what this method does under the hood.
- Jumping into it (Ctrl+B), the underlying method is shown below. It has a default parameter numSlices = defaultParallelism, which is the number of partitions, and it creates a new ParallelCollectionRDD instance.
def parallelize[T: ClassTag](
seq: Seq[T],
numSlices: Int = defaultParallelism): RDD[T] = withScope {
assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
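- For reference, the effect of numSlices can be checked directly on the returned RDD; a minimal sketch using the standard getNumPartitions API (the collection and the explicit slice count are just examples):
val rddDefault  = sc.parallelize(List(1, 2, 3, 4, 5, 6))     // numSlices defaults to defaultParallelism
val rddExplicit = sc.parallelize(List(1, 2, 3, 4, 5, 6), 4)  // numSlices given explicitly
rddExplicit.getNumPartitions                                 // 4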
- First, let's figure out what the returned ParallelCollectionRDD actually is.
- Jumping into the ParallelCollectionRDD class, its source looks like this:
private[spark] class ParallelCollectionRDD[T: ClassTag](
sc: SparkContext,
@transient private val data: Seq[T],
numSlices: Int,
locationPrefs: Map[Int, Seq[String]])
extends RDD[T](sc, Nil)
- As you can see, it is built by extending the abstract RDD class.
- Next, let's trace how the default argument numSlices = defaultParallelism is determined.
- Jumping into defaultParallelism, we find it is a method that returns taskScheduler.defaultParallelism:
def defaultParallelism: Int = {
assertNotStopped()
taskScheduler.defaultParallelism
}
- Drilling in further, we find that this is an abstract method:
def defaultParallelism(): Int
- Using Ctrl+H to list the subclasses of this abstract class, we find the overriding defaultParallelism method (in TaskSchedulerImpl):
override def defaultParallelism(): Int = backend.defaultParallelism()
- It delegates to backend.defaultParallelism(), so let's follow that call.
def defaultParallelism(): Int
- This is another abstract method, so again we look for the overriding method in a subclass:
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
- Here the value is obtained by calling scheduler.conf.getInt with two arguments:
def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) {
getOption(key).map(_.toInt).getOrElse(defaultValue)
}
- Inside getInt, the first argument (the key) is used to look up a hash map; if the key is not found, the second argument is used as the partition count.
- Let's look inside getOption:
def getOption(key: String): Option[String] = {
Option(settings.get(key)).orElse(getDeprecatedConfig(key, settings))
}
- The map being queried is settings, so let's see what settings actually is:
class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable {
import SparkConf._
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
private val settings = new ConcurrentHashMap[String, String]()
- So settings is a ConcurrentHashMap and a member of the SparkConf class.
- Next, let's see how entries get into settings, i.e. which methods in this class add elements to it:
private[spark] def set(key: String, value: String, silent: Boolean): SparkConf = {
settings.put(key, value)
this
}
def setMaster(master: String): SparkConf = {
set("spark.master", master)
}
/** Set a name for your application. Shown in the Spark web UI. */
def setAppName(name: String): SparkConf = {
set("spark.app.name", name)
}
- At this point it is clear: the partition count is whatever we set for spark.default.parallelism when building the SparkConf, and if that key is not set it falls back to totalCores.
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
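- For illustration, a minimal sketch of overriding this default through configuration (the master string and the values are just examples):
val conf = new SparkConf()
  .setMaster("local[4]")
  .setAppName("parallelism-demo")
  .set("spark.default.parallelism", "8")   // takes precedence over totalCores
val sc = new SparkContext(conf)
sc.parallelize(1 to 100).getNumPartitions  // 8 because the key is set; otherwise it would be 4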
- But what is the value of totalCores?
private[spark] class LocalSchedulerBackend(
conf: SparkConf,
scheduler: TaskSchedulerImpl,
val totalCores: Int)
- totalCores is a constructor parameter of LocalSchedulerBackend, so its value is supplied wherever an instance of LocalSchedulerBackend is created.
- In SparkContext.scala, the createTaskScheduler method contains new LocalSchedulerBackend(sc.getConf, scheduler, 1):
private def createTaskScheduler(
sc: SparkContext,
master: String,
deployMode: String): (SchedulerBackend, TaskScheduler) = {
master match {
case "local" =>
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
scheduler.initialize(backend)
(backend, scheduler)
- If master matches the literal "local", the backend is created with 1 core, giving one default partition; otherwise matching falls through to the next case:
case LOCAL_N_REGEX(threads) =>
def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
// local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
val threadCount = if (threads == "*") localCpuCount else threads.toInt
if (threadCount <= 0) {
throw new SparkException(s"Asked to run locally with $threadCount threads")
}
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
scheduler.initialize(backend)
(backend, scheduler)
- The next case matches against the LOCAL_N_REGEX(threads) extractor; let's see how that regex is defined:
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
- The captured group threads is either one or more digits or *. As the code below shows, a number is used as-is, while * falls back to localCpuCount:
val threadCount = if (threads == "*") localCpuCount else threads.toInt
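- A quick standalone sketch of how this extractor behaves (illustrative only):
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
"local[4]" match { case LOCAL_N_REGEX(threads) => threads }   // "4"
"local[*]" match { case LOCAL_N_REGEX(threads) => threads }   // "*"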
- So how big is localCpuCount? It relies on Runtime.getRuntime.availableProcessors():
def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
- Jumping in one last time, availableProcessors is a native method that returns the number of processors available to the JVM, which is effectively the number of CPU cores.
public native int availableProcessors();
- In short, the default partition count depends on the setMaster setting: "local" gives one partition, "local[*]" uses the number of CPU cores, and "local[N]" uses exactly N.
new SparkConf().setMaster("local[*]")
Source code analysis of the textFile method
- textFile creates an RDD from a file:
val rdd2 = sc.textFile("E:\\projectstudy\\sparkstudy\\aaa.txt")
- Looking at the source, besides the file path it also takes a default minimum partition count, and it returns the result of the hadoopFile method:
def textFile(
path: String,
minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
- Looking at hadoopFile first, it returns a HadoopRDD instance:
def hadoopFile[K, V](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
assertNotStopped()
val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
new HadoopRDD(
this,
confBroadcast,
Some(setInputPathsFunc),
inputFormatClass,
keyClass,
valueClass,
minPartitions).setName(path)
}
- HadoopRDD is a custom class that extends RDD; it overrides the partitioning function, the compute function, and other methods.
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableConfiguration],
initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int)
extends RDD[(K, V)](sc, Nil) with Logging
- Now let's see where the default minimum partition count defaultMinPartitions comes from: it is the smaller of the default parallelism and 2.
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
- And the default parallelism, as we saw in the parallelize analysis above, is determined by the master setting.
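- As a quick worked example (assuming no spark.default.parallelism override):
// master = "local[8]" -> defaultParallelism = 8 -> defaultMinPartitions = math.min(8, 2) = 2
// master = "local"    -> defaultParallelism = 1 -> defaultMinPartitions = math.min(1, 2) = 1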
Analysis of the parallelize method's partitioning logic
- Create an RDD from a local collection with three partitions, then save it to files:
val rdd1 = sc.parallelize(List(1,2,3,4,5),3)
rdd1.saveAsTextFile("E:\\projectstudy\\sparkstudy\\out")
- The data in each partition is as follows:
Partition 1: 1
Partition 2: 2 3
Partition 3: 4 5
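- As a side note, the same assignment can be inspected without writing to files, using the standard mapPartitionsWithIndex API (a minimal sketch; partition indices are zero-based):
rdd1.mapPartitionsWithIndex { (idx, it) =>
  it.map(x => s"partition $idx -> $x")
}.collect().foreach(println)
// partition 0 -> 1, partition 1 -> 2, partition 1 -> 3, partition 2 -> 4, partition 2 -> 5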
- Now let's dig into exactly how the data is partitioned when the RDD is created.
- The parallelize method, shown again below, returns a ParallelCollectionRDD:
def parallelize[T: ClassTag](
seq: Seq[T],
numSlices: Int = defaultParallelism): RDD[T] = withScope {
assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
- So the partitioning logic must be implemented in ParallelCollectionRDD by overriding methods of the abstract RDD class.
- Looking at the methods declared on the abstract RDD class, the one a subclass overrides to implement partitioning is clearly getPartitions:
def compute(split: Partition, context: TaskContext): Iterator[T]
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*
* The partitions in this array must satisfy the following property:
* `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
*/
protected def getPartitions: Array[Partition]
protected def getDependencies: Seq[Dependency[_]] = deps
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
- In the subclass's getPartitions, each slice is produced by the slice method of the ParallelCollectionRDD companion object:
override def getPartitions: Array[Partition] = {
val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
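- A quick sanity check that the number of ParallelCollectionPartition objects equals numSlices (a sketch, reusing rdd1 = sc.parallelize(List(1,2,3,4,5), 3) from above):
rdd1.partitions.length   // 3, one ParallelCollectionPartition per slice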
- Inside slice, the sequence and the number of slices are passed in, and the sequence is pattern matched on its type:
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
seq match {
case r: Range =>
positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
// If the range is inclusive, use inclusive range for the last slice
if (r.isInclusive && index == numSlices - 1) {
new Range.Inclusive(r.start + start * r.step, r.end, r.step)
}
else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}
}.toSeq.asInstanceOf[Seq[Seq[T]]]
case nr: NumericRange[_] =>
// For ranges of Long, Double, BigInteger, etc
val slices = new ArrayBuffer[Seq[T]](numSlices)
var r = nr
for ((start, end) <- positions(nr.length, numSlices)) {
val sliceSize = end - start
slices += r.take(sliceSize).asInstanceOf[Seq[T]]
r = r.drop(sliceSize)
}
slices
case _ =>
val array = seq.toArray // To prevent O(n^2) operations for List etc
positions(array.length, numSlices).map { case (start, end) =>
array.slice(start, end).toSeq
}.toSeq
  }
}
- For a plain List like the one in our example, execution falls into the last case:
val array = seq.toArray // To prevent O(n^2) operations for List etc
positions(array.length, numSlices).map { case (start, end) =>
array.slice(start, end).toSeq
}.toSeq
- The positions(array.length, numSlices) call iterates over the partition indices and produces a (start, end) index pair for each partition:
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
(0 until numSlices).iterator.map { i =>
val start = ((i * length) / numSlices).toInt
val end = (((i + 1) * length) / numSlices).toInt
(start, end)
}
}
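- Working this through for our example (length = 5, numSlices = 3) is plain integer arithmetic:
// i = 0: start = 0 * 5 / 3 = 0, end = 1 * 5 / 3 = 1  -> (0, 1)
// i = 1: start = 1 * 5 / 3 = 1, end = 2 * 5 / 3 = 3  -> (1, 3)
// i = 2: start = 2 * 5 / 3 = 3, end = 3 * 5 / 3 = 5  -> (3, 5)
// array.slice then yields List(1), List(2, 3) and List(4, 5), matching the output shown earlier.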
- For each partition, array.slice is called with that partition's start and end indices, and the result is converted to a Seq:
array.slice(start, end).toSeq
- Finally, the compute function simply wraps each partition's iterator so its elements can be traversed:
override def compute(s: Partition, context: TaskContext): Iterator[T] = {
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}
Analysis of the textFile method's partitioning logic
- textFile reads a file into an RDD with a minimum of 3 partitions; the file contains the lines aaa, bbb, ccc, ddd, eee:
val rdd1 = sc.textFile("E:\\projectstudy\\sparkstudy\\aaa.txt",3)
- The RDD actually ends up with four partitions, with the following contents:
Partition 1: aaa bbb
Partition 2: ccc
Partition 3: ddd eee
Partition 4: (empty)
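- The partition contents above can be checked directly with the standard glom API (a minimal sketch, reusing rdd1 from above):
rdd1.glom().collect().zipWithIndex.foreach { case (lines, idx) =>
  println(s"partition $idx: ${lines.mkString(" ")}")
}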
- Now let's walk through the source and work out the actual partitioning rules.
- First, the textFile method, shown again below, delegates to hadoopFile:
def textFile(
path: String,
minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
- hadoopFile returns a HadoopRDD object:
def hadoopFile[K, V](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
new HadoopRDD(
this,
confBroadcast,
Some(setInputPathsFunc),
inputFormatClass,
keyClass,
valueClass,
minPartitions).setName(path)
}
- HadoopRDD is a custom class extending RDD, so it must override the compute function, the partitioning function, and so on.
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableConfiguration],
initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int)
extends RDD[(K, V)](sc, Nil) with Logging
- HadoopRDD's partitioning function is shown below; the real partitioning logic is in the getSplits call:
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
// add the credentials here as this can be called before SparkContext initialized
SparkHadoopUtil.get.addCredentials(jobConf)
val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
val inputSplits = if (ignoreEmptySplits) {
allInputSplits.filter(_.getLength > 0)
} else {
allInputSplits
}
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
array(i) = new HadoopPartition(id, i, inputSplits(i))
}
array
}
- Following getSplits, we find it is an abstract method on the InputFormat interface:
public interface InputFormat<K, V> {
InputSplit[] getSplits(JobConf var1, int var2) throws IOException;
RecordReader<K, V> getRecordReader(InputSplit var1, JobConf var2, Reporter var3) throws IOException;
}
- Using Ctrl+H again, we look at the getSplits implementation in the FileInputFormat subclass (comments added for explanation):
public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
  Stopwatch sw = new Stopwatch().start();
  // list all files under the input paths
  FileStatus[] files = listStatus(job);

  // Save the number of input files for metrics/loadgen
  job.setLong(NUM_INPUT_FILES, files.length);
  // total size of all input files, in bytes
  long totalSize = 0;                           // compute total size
  // iterate over the files and add up their sizes
  for (FileStatus file: files) {                // check we have valid files
    if (file.isDirectory()) {
      throw new IOException("Not a file: "+ file.getPath());
    }
    totalSize += file.getLen();
  }

  // total size divided by the requested number of splits = expected size of each split
  long goalSize = totalSize / (numSplits == 0 ? 1 : numSplits);
  // minimum allowed split size
  long minSize = Math.max(job.getLong(org.apache.hadoop.mapreduce.lib.input.
    FileInputFormat.SPLIT_MINSIZE, 1), minSplitSize);

  // the split plan that will be returned
  ArrayList<FileSplit> splits = new ArrayList<FileSplit>(numSplits);
  NetworkTopology clusterMap = new NetworkTopology();
  // iterate over the files again, this time planning the splits
  for (FileStatus file: files) {
    Path path = file.getPath();                 // current file path
    long length = file.getLen();                // current file size
    if (length != 0) {                          // the current file is not empty
      FileSystem fs = path.getFileSystem(job);  // get the FileSystem
      BlockLocation[] blkLocations;             // block locations of the file
      if (file instanceof LocatedFileStatus) {
        blkLocations = ((LocatedFileStatus) file).getBlockLocations();
      } else {
        blkLocations = fs.getFileBlockLocations(file, 0, length);
      }
      // if the current file is splittable
      if (isSplitable(fs, path)) {
        long blockSize = file.getBlockSize();   // block size of the file (128 MB by default on HDFS)
        // compute the final split size from the goal size, the minimum size and the block size
        long splitSize = computeSplitSize(goalSize, minSize, blockSize);

        // bytesRemaining starts as the total number of bytes in the file
        long bytesRemaining = length;
        // keep splitting while the remaining bytes divided by the split size exceed SPLIT_SLOP (1.1)
        while (((double) bytesRemaining)/splitSize > SPLIT_SLOP) {
          String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,
              length-bytesRemaining, splitSize, clusterMap);
          // plan one split: which file, starting at which offset, and how many bytes
          splits.add(makeSplit(path, length-bytesRemaining, splitSize,
              splitHosts[0], splitHosts[1]));
          // subtract the split size and continue with the next iteration
          bytesRemaining -= splitSize;
        }

        // if there are leftover bytes, put them into one final split of their own
        // e.g. a 10-byte file with a requested minimum of 3 splits leaves 1 byte as its own split
        if (bytesRemaining != 0) {
          String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,
              length - bytesRemaining, bytesRemaining, clusterMap);
          splits.add(makeSplit(path, length - bytesRemaining, bytesRemaining,
              splitHosts[0], splitHosts[1]));
        }
      } else {
        // the file is not splittable: one split covering the whole file
        String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,0,length,clusterMap);
        splits.add(makeSplit(path, 0, length, splitHosts[0], splitHosts[1]));
      }
    } else {
      // Create empty hosts array for zero length files
      splits.add(makeSplit(path, 0, length, new String[0]));
    }
  }
  // return the final split plan
  return splits.toArray(new FileSplit[splits.size()]);
}
- Once the split plan has been produced, partitions are created from it; each HadoopPartition object is one partition:
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
// add the credentials here as this can be called before SparkContext initialized
SparkHadoopUtil.get.addCredentials(jobConf)
val inputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
array(i) = new HadoopPartition(id, i, inputSplits(i))
}
array
}
- getPartitions only creates the partitions (an array of Partition objects); which data ends up in each partition is determined by the compute method.
- compute is what actually reads the file data into a partition: given a Partition object, it returns an iterator over that partition's records.
- The key parts of compute are shown below; internally it calls inputFormat's getRecordReader method:
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
  val iter = new NextIterator[(K, V)] {
    private val split = theSplit.asInstanceOf[HadoopPartition]
    private var reader: RecordReader[K, V] = null
    private val inputFormat = getInputFormat(jobConf)
    // (error handling around the reader creation is elided here)
    reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
    // Register an on-task-completion callback to close the input stream.
    context.addTaskCompletionListener[Unit] { context =>
      // Update the bytes read before closing is to make sure lingering bytesRead statistics in
      // this thread get correctly added.
      updateBytesRead()
      closeIfNeeded()
    }
    private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
    private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
    override def getNext(): (K, V) = {
      // (error handling around reader.next is elided here)
      finished = !reader.next(key, value)
      (key, value)
    }
  }
  new InterruptibleIterator[(K, V)](context, iter)
}
- getRecordReader is an abstract method on the same interface:
RecordReader<K, V> getRecordReader(InputSplit split,
JobConf job,
Reporter reporter) throws IOException;
- Looking at getRecordReader in the TextInputFormat subclass, it returns a LineRecordReader object:
public RecordReader<LongWritable, Text> getRecordReader(
InputSplit genericSplit, JobConf job,
Reporter reporter)
throws IOException {
reporter.setStatus(genericSplit.toString());
String delimiter = job.get("textinputformat.record.delimiter");
byte[] recordDelimiterBytes = null;
if (null != delimiter) {
recordDelimiterBytes = delimiter.getBytes(Charsets.UTF_8);
}
return new LineRecordReader(job, (FileSplit) genericSplit,
recordDelimiterBytes);
}
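- As an aside, the delimiter looked up here is configurable; a hedged sketch of setting it through the Hadoop configuration before calling textFile (the key name is the one read in the code above):
sc.hadoopConfiguration.set("textinputformat.record.delimiter", "\n")
val rdd = sc.textFile("E:\\projectstudy\\sparkstudy\\aaa.txt", 3)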
- The LineRecordReader is constructed with the planned split, and when the data is eventually read, the reading is driven by that split's start and length:
public LineRecordReader(Configuration job, FileSplit split,
byte[] recordDelimiter) throws IOException {
this.maxLineLength = job.getInt(org.apache.hadoop.mapreduce.lib.input.
LineRecordReader.MAX_LINE_LENGTH, Integer.MAX_VALUE);
start = split.getStart();
end = start + split.getLength();
final Path file = split.getPath();
compressionCodecs = new CompressionCodecFactory(job);
codec = compressionCodecs.getCodec(file);
- Note, however, that LineRecordReader reads the data line by line:
public synchronized boolean next(LongWritable key, Text value)
throws IOException {
while (getFilePosition() <= end || in.needAdditionalRecordAfterSplit()) {
key.set(pos);
int newSize = 0;
newSize = in.readLine(value, maxLineLength, maxBytesToConsume(pos));
pos += newSize;
if (newSize == 0) {
return false;
}
if (newSize < maxLineLength) {
return true;
}
}
return false;
}
- Suppose the file now contains the following data:
0123456789abcde\r\n
fgx\r\n
hi
- Stepping through with a breakpoint in the debugger, the intermediate values are as follows; the split planning evenly assigns 8 bytes to each split:
totalSize = 24   // file size in bytes
goalSize = 8     // expected size of each split
minSize = 1      // minimum size of each split
numSplits = 3    // requested number of splits
splitSize = 8    // actual size of each split
// the resulting split plan:
0 = {FileSplit@5433} "file:/E:/projectstudy/sparkstudy/aaa.txt:0+8"
1 = {FileSplit@5434} "file:/E:/projectstudy/sparkstudy/aaa.txt:8+8"
2 = {FileSplit@5471} "file:/E:/projectstudy/sparkstudy/aaa.txt:16+8"
- But the actual partition contents are not like that:
Partition 1: 0123456789abcde\r\n
Partition 2: (empty)
Partition 3: fgx\r\n
hi
- This is because the default LineRecordReader reads whole lines. Although split 1 only covers bytes 0-8, the reader hands the entire first line to partition 1, which moves the read position to byte 17. Split 2 covers bytes 8-16, but the first line boundary after its start is already at 17, past its end, so partition 2 gets no data. Split 3 covers bytes 16-24: it skips ahead to the line boundary at 17, reads the second line, still has range left, and therefore reads the third line as well, so partition 3 gets both remaining lines.
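- To make this concrete, here is the byte layout of the example file and how each split resolves to records (offsets assume \r\n line endings, as in the data above):
"0123456789abcde\r\n"  -> bytes  0..16  (17 bytes)
"fgx\r\n"              -> bytes 17..21  ( 5 bytes)
"hi"                   -> bytes 22..23  ( 2 bytes)
Split 0 covers [0, 8):   starts at 0, so it reads the whole first line            -> 1 record
Split 1 covers [8, 16):  skips to the next line boundary at 17, past its end      -> 0 records
Split 2 covers [16, 24): skips to the boundary at 17, then reads "fgx" and "hi"   -> 2 records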