Spark: Default Partition Count Rules When Creating an RDD

Three Ways to Create an RDD

1) From a collection (in memory). Methods: parallelize, makeRDD

2) From external storage. Method: textFile

3) From another RDD (by applying a transformation operator). A sketch of all three follows this list.
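A minimal sketch of all three ways, assuming an existing SparkContext named sc (the input path is a placeholder):

// 1) From an in-memory collection
val fromCollection = sc.makeRDD(List(1, 2, 3, 4))   // or sc.parallelize(List(1, 2, 3, 4))

// 2) From external storage
val fromFile = sc.textFile("path/to/input")          // placeholder path

// 3) From another RDD, by applying a transformation operator
val fromTransformation = fromCollection.map(_ * 2)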

1) Creating from a collection (in memory): parallelize and makeRDD

1. First, let's see what partitioning rule applies to an RDD created this way.
Code:

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object test02_RDDDefalutPatirion {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("RDDDefalutPrtition").setMaster("local[*]")
    val sc = new SparkContext(conf)

    // Create an RDD from an in-memory collection without specifying the number of slices
    val listRDD: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4))

    // Print the number of partitions Spark chose by default
    println(listRDD.partitions.size)

    // Writing the RDD out produces one part file per partition
    listRDD.saveAsTextFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\output")

    sc.stop()
  }
}

First, the result of running this: 8
In other words, 8 partitions were created. So where do these 8 partitions come from?
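Before tracing the source, a quick way to confirm how the four elements are spread across those 8 partitions is glom, which gathers each partition into an array (a minimal sketch, to be run inside the same main method; most partitions will simply be empty here):

// Print the contents of every partition
listRDD.glom().collect().zipWithIndex.foreach { case (elems, idx) =>
  println(s"partition $idx -> ${elems.mkString(", ")}")
}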
2. Let's first step into the makeRDD method

  def makeRDD[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    parallelize(seq, numSlices)
  }

As you can see, makeRDD actually takes two parameters: seq and numSlices.
seq: the collection we pass in
numSlices: the number of slices, i.e. the number of partitions, which has a default value of defaultParallelism

Incidentally, under the hood makeRDD simply calls parallelize.
So to understand the partitioning rule, we need to keep following defaultParallelism.
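As a side note, when numSlices is passed explicitly, the default is never consulted (a minimal sketch using the same sc as above):

// Explicitly request 2 partitions; defaultParallelism is not used in this case
val twoPartRDD: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 2)
println(twoPartRDD.partitions.size)  // 2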
3. The code for defaultParallelism is as follows:

  /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
  def defaultParallelism: Int = {
    assertNotStopped()
    taskScheduler.defaultParallelism
  }

You can see that defaultParallelism returns taskScheduler.defaultParallelism, so we keep digging.
4. The code for taskScheduler.defaultParallelism is as follows:

  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
  def defaultParallelism(): Int

Here it is an abstract method on a trait, so we need to find its implementation.
The implementation is as follows:

 override def defaultParallelism(): Int = backend.defaultParallelism()

Following the code further, we again land on an abstract method on a trait:

def defaultParallelism(): Int

So once again we need the implementation. Two implementing classes show up; since we are running the code locally, the one to look at is LocalSchedulerBackend.scala:

  override def defaultParallelism(): Int =
    scheduler.conf.getInt("spark.default.parallelism", totalCores)

We keep following the source; the getInt method looks like this:

  /** Get a parameter as an integer, falling back to a default if not set */
  def getInt(key: String, defaultValue: Int): Int = {
    getOption(key).map(_.toInt).getOrElse(defaultValue)
  }

The logic here is obvious: look up the value for the key, convert it to an Int, and fall back to defaultValue if the key is not set.
Since we never set spark.default.parallelism, the call falls back to defaultValue. The caller above is scheduler.conf.getInt("spark.default.parallelism", totalCores), so next we need to see where totalCores comes from.
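This also means the default can be overridden without touching the master string, by setting spark.default.parallelism on the SparkConf (a minimal sketch; spark.default.parallelism is a standard Spark property):

// With this setting, getInt finds the key, totalCores is never consulted,
// and makeRDD without an explicit numSlices creates 6 partitions
val conf = new SparkConf()
  .setAppName("RDDDefalutPrtition")
  .setMaster("local[*]")
  .set("spark.default.parallelism", "6")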
5. The code for totalCores is as follows:

private[spark] class LocalSchedulerBackend(
    conf: SparkConf,
    scheduler: TaskSchedulerImpl,
    val totalCores: Int)

totalCores is a constructor parameter of LocalSchedulerBackend, so we need to see what value is passed in when LocalSchedulerBackend is constructed. For that, we go back to SparkContext.
6. In SparkContext, LocalSchedulerBackend is created inside a match on the master string:

master match {
      case "local" =>
        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
        scheduler.initialize(backend)
        (backend, scheduler)

      case LOCAL_N_REGEX(threads) =>
        def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
        // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
        val threadCount = if (threads == "*") localCpuCount else threads.toInt
        if (threadCount <= 0) {
          throw new SparkException(s"Asked to run locally with $threadCount threads")
        }
        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
        scheduler.initialize(backend)
        (backend, scheduler)

      case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
        def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
        // local[*, M] means the number of cores on the computer with M failures
        // local[N, M] means exactly N threads with M failures
        val threadCount = if (threads == "*") localCpuCount else threads.toInt
        val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
        scheduler.initialize(backend)
        (backend, scheduler)

      case SPARK_REGEX(sparkUrl) =>
        val scheduler = new TaskSchedulerImpl(sc)
        val masterUrls = sparkUrl.split(",").map("spark://" + _)
        val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
        scheduler.initialize(backend)
        (backend, scheduler)

      case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
        // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
        val memoryPerSlaveInt = memoryPerSlave.toInt
        if (sc.executorMemory > memoryPerSlaveInt) {
          throw new SparkException(
            "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
              memoryPerSlaveInt, sc.executorMemory))
        }
        val scheduler = new TaskSchedulerImpl(sc)
        val localCluster = new LocalSparkCluster(
          numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf)
        val masterUrls = localCluster.start()
        val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
        scheduler.initialize(backend)
        backend.shutdownCallback = (backend: StandaloneSchedulerBackend) => {
          localCluster.stop()
        }
        (backend, scheduler)

      case masterUrl =>
        val cm = getClusterManager(masterUrl) match {
          case Some(clusterMgr) => clusterMgr
          case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
        }
        try {
          val scheduler = cm.createTaskScheduler(sc, masterUrl)
          val backend = cm.createSchedulerBackend(sc, masterUrl, scheduler)
          cm.initialize(scheduler, backend)
          (backend, scheduler)
        } catch {
          case se: SparkException => throw se
          case NonFatal(e) =>
            throw new SparkException("External scheduler cannot be instantiated", e)
        }
    }

The regular expressions used by these cases are defined as follows:

private object SparkMasterRegex {
  // Regular expression used for local[N] and local[*] master formats
  val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
  // Regular expression for local[N, maxRetries], used in tests with failing tasks
  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
  // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
  val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
  // Regular expression for connecting to Spark deploy clusters
  val SPARK_REGEX = """spark://(.*)""".r
}
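As a quick sanity check, you can see which pattern a given master string hits (a minimal standalone sketch reusing the same regex):

// The same pattern as LOCAL_N_REGEX in SparkMasterRegex
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r

"local[*]" match {
  case LOCAL_N_REGEX(threads) => println(s"matched LOCAL_N_REGEX, threads = $threads")  // threads = *
  case _                      => println("no match")
}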

Since we created the SparkContext with setMaster("local[*]"), the master string matches the second case:

case LOCAL_N_REGEX(threads) =>
        def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
        // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
        val threadCount = if (threads == "*") localCpuCount else threads.toInt
        if (threadCount <= 0) {
          throw new SparkException(s"Asked to run locally with $threadCount threads")
        }
        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
        val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
        scheduler.initialize(backend)
        (backend, scheduler)

The line we actually care about is val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount), and what we need to pin down is the value of threadCount. That is decided by val threadCount = if (threads == "*") localCpuCount else threads.toInt. Since we passed local[*], threadCount is localCpuCount, which is defined as def localCpuCount: Int = Runtime.getRuntime.availableProcessors(). So next we look at availableProcessors:

    public native int availableProcessors();

This returns the number of available processors (logical cores) on the current machine.
My CPU has 8 cores, so the default number of partitions here is 8.
If I change local[*] to local[4], the default number of partitions becomes 4.
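So for the collection-based case the chain is roughly: master string -> totalCores -> defaultParallelism -> numSlices. This is easy to verify (a minimal standalone sketch; availableProcessors is the same call the LOCAL_N_REGEX branch uses):

// What local[*] resolves to on this machine
println(Runtime.getRuntime.availableProcessors())   // 8 on an 8-core machine

// With local[4] instead, the same makeRDD call yields 4 partitions
val conf4 = new SparkConf().setAppName("RDDDefalutPrtition").setMaster("local[4]")
val sc4 = new SparkContext(conf4)
println(sc4.makeRDD(List(1, 2, 3, 4)).partitions.size)   // 4
sc4.stop()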

2) Creating from external storage: textFile

1. Now let's look at the default partitioning when an RDD is created from external storage.
The code is as follows:

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object test02_RDDDefalutPatirion {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("RDDDefalutPrtition").setMaster("local[4]")
    val sc = new SparkContext(conf)

    val listRDD: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4))

    // Read a directory of text files without specifying minPartitions
    val fileRDD: RDD[String] = sc.textFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\input")
    // println(listRDD.partitions.size)
    println(fileRDD.partitions.size)

    // listRDD.saveAsTextFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\output")

    sc.stop()
  }
}

The output is 2.
2. Let's step into the textFile source:

  /**
   * Read a text file from HDFS, a local file system (available on all nodes), or any
   * Hadoop-supported file system URI, and return it as an RDD of Strings.
   */
  def textFile(
      path: String,
      minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
    assertNotStopped()
    hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
      minPartitions).map(pair => pair._2.toString).setName(path)
  }

You can see it also takes two parameters: path and minPartitions.
path: the file path we pass in
minPartitions: the minimum number of partitions, which defaults to defaultMinPartitions
3. The code for defaultMinPartitions is as follows:

def defaultMinPartitions: Int = math.min(defaultParallelism, 2)

In other words, it takes the smaller of defaultParallelism and 2. We covered how defaultParallelism is determined above; with setMaster("local[4]") it is 4 here, so the smaller value, 2, is used. Keep in mind that minPartitions is only a lower-bound hint: the actual number of partitions for textFile is decided by Hadoop's input-split calculation, which in this case also comes out to 2.
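If you need more read parallelism, minPartitions can be passed explicitly instead of relying on defaultMinPartitions (a minimal sketch; the path is the same placeholder directory as above):

// Ask for at least 4 partitions when reading; Hadoop's split logic may still produce more
val fileRDD4: RDD[String] = sc.textFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\input", 4)
println(fileRDD4.partitions.size)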
