Spark RPC Source Code Analysis

Spark's RPC layer is built on the Netty communication framework.

Core classes: RpcEnv, RpcEndpoint, RpcEndpointRef

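For example, the Spark Master and each Worker are RpcEndpoints (communication endpoints). Other components can only talk to an RpcEndpoint through an RpcEndpointRef obtained from the RpcEnv. By default Spark uses NettyRpcEnv, the Netty-based implementation that replaced the earlier Akka-based RPC.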

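When a client sends a message through an RpcEndpointRef, the message is first handled by the RpcEnv, which works out who the recipient is and routes the message to that RpcEndpoint.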

 

1. RpcEnv: the RPC environment. It manages the whole lifecycle of each RpcEndpoint, from registration through message handling to shutdown. An RpcEnv can only be created through an RpcEnvFactory. Its core method is:

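def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef // registers the RpcEndpoint with the RpcEnv under name and returns its RpcEndpointRef; the RpcEnv keeps the binding between each RpcEndpoint and its RpcEndpointRef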
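The registration-and-routing relationship can be sketched in plain Scala. This is an in-process toy, not Spark code: `ToyRpcEnv`, `ToyEndpoint`, and `ToyEndpointRef` are hypothetical names, there is no network or serialization, and messages are delivered synchronously on the caller's thread. Only the shape mirrors Spark: endpoints are registered under a unique name, and callers hold a ref rather than the endpoint itself.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for RpcEndpoint: just the one-way receive hook.
trait ToyEndpoint {
  def receive: PartialFunction[Any, Unit]
}

// Hypothetical stand-in for RpcEndpointRef: it only knows the endpoint's
// name and the environment that routes messages for it.
final class ToyEndpointRef(val name: String, env: ToyRpcEnv) {
  def send(message: Any): Unit = env.route(name, message)
}

// Hypothetical stand-in for RpcEnv: owns the name -> endpoint binding.
final class ToyRpcEnv {
  private val endpoints = new ConcurrentHashMap[String, ToyEndpoint]()

  // Register an endpoint under a unique name and hand back its reference.
  def setupEndpoint(name: String, endpoint: ToyEndpoint): ToyEndpointRef = {
    if (endpoints.putIfAbsent(name, endpoint) != null) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    new ToyEndpointRef(name, this)
  }

  // Look up the endpoint registered under `name` and deliver the message to it.
  def route(name: String, message: Any): Unit = {
    val endpoint = endpoints.get(name)
    if (endpoint == null) throw new IllegalStateException(s"Could not find $name")
    endpoint.receive.applyOrElse(message, (m: Any) => throw new IllegalArgumentException(s"Unhandled message: $m"))
  }
}

// Usage: a "Master" endpoint that records what it receives.
val received = ArrayBuffer.empty[Any]
val rpcEnv = new ToyRpcEnv
val master = rpcEnv.setupEndpoint("Master", new ToyEndpoint {
  def receive: PartialFunction[Any, Unit] = { case msg => received += msg }
})
master.send("RegisterWorker")
```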

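2. RpcEndpoint: the object that takes part in RPC communication. It defines the endpoint lifecycle (constructor -> onStart -> receive -> onStop) and event-driven callbacks for connection, disconnection, and network errors (onConnected, onDisconnected, onNetworkError). At its core, an RpcEndpoint receives and handles messages: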

def receive: PartialFunction[Any, Unit] = {
  case _ => throw new SparkException(self + " does not implement 'receive'")
}

def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}

RpcEndpoint.receive handles messages sent with RpcEndpointRef.send(). The RpcEndpoint processes them directly and sends no reply.

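RpcEndpoint.receiveAndReply() handles messages sent with RpcEndpointRef.ask(). After processing such a message, the RpcEndpoint must send a response back to the caller of ask through the RpcCallContext (reply() or sendFailure()).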
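To show how the two hooks pair with send and ask, here is a single-threaded toy. `CounterEndpoint`, `CounterRef`, and `ToyCallContext` are hypothetical; the only real APIs used are `scala.concurrent.Promise` and `Future`. As a simplifying assumption, the ref calls the endpoint directly instead of going through an RpcEnv and Dispatcher.

```scala
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.Try

// Hypothetical reply handle modelled on RpcCallContext: it completes a Promise.
final class ToyCallContext(promise: Promise[Any]) {
  def reply(response: Any): Unit = promise.trySuccess(response)
  def sendFailure(e: Throwable): Unit = promise.tryFailure(e)
}

// Hypothetical endpoint with the two message hooks of RpcEndpoint.
final class CounterEndpoint {
  private var count = 0

  // One-way messages (delivered by send): handled, never answered.
  def receive: PartialFunction[Any, Unit] = {
    case "inc" => count += 1
  }

  // Request/response messages (delivered by ask): answered via the context.
  def receiveAndReply(context: ToyCallContext): PartialFunction[Any, Unit] = {
    case "get" => context.reply(count)
  }
}

// Hypothetical ref: send is fire-and-forget, ask returns a Future of the reply.
final class CounterRef(endpoint: CounterEndpoint) {
  def send(message: Any): Unit = endpoint.receive(message)

  def ask[T: ClassTag](message: Any): Future[T] = {
    val promise = Promise[Any]()
    val context = new ToyCallContext(promise)
    // Unknown requests fail the Future, like the default receiveAndReply.
    endpoint.receiveAndReply(context).applyOrElse(message,
      (m: Any) => context.sendFailure(new IllegalArgumentException(s"won't reply to $m")))
    promise.future.mapTo[T]
  }
}

val counterRef = new CounterRef(new CounterEndpoint)
counterRef.send("inc")
counterRef.send("inc")
val current = Await.result(counterRef.ask[Int]("get"), 1.second)
val unknown = Try(Await.result(counterRef.ask[Int]("reset"), 1.second))
```

The caller of ask only ever sees a Future; whether the reply is produced locally or after a network round trip is hidden behind the Promise.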

 

3. RpcEndpointRef

An RpcEndpointRef is a reference to a (possibly remote) RpcEndpoint, similar to a remote proxy:

private[spark] abstract class RpcEndpointRef(conf: SparkConf)  extends Serializable with Logging { 
  private[this] val maxRetries = RpcUtils.numRetries(conf)
  private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
  private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
  def address: RpcAddress
  def name: String
  def send(message: Any): Unit
  def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
  def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
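    ... // (body omitted) calls ask() up to maxRetries times, blocking on each returned Future for the reply and waiting retryWaitMs between attempts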
  }
}
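The retry loop behind askWithRetry can be sketched as a free-standing helper. This is an approximation for illustration: the helper's signature is hypothetical (it takes the ask as a function instead of being a method on the ref), and after the last failed attempt it throws a plain RuntimeException carrying the last error as its cause.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

// Hypothetical free-standing version of the askWithRetry loop.
// `ask` stands in for RpcEndpointRef.ask(message, timeout).
def askWithRetry[T](ask: () => Future[T], maxRetries: Int, retryWaitMs: Long,
                    timeout: FiniteDuration): T = {
  var attempts = 0
  var lastException: Exception = null
  while (attempts < maxRetries) {
    attempts += 1
    try {
      // Block until the reply arrives or the timeout expires.
      val result = Await.result(ask(), timeout)
      if (result == null) throw new IllegalStateException("RpcEndpoint returned null")
      return result
    } catch {
      case ie: InterruptedException => throw ie // never swallow interrupts
      case e: Exception => lastException = e   // remember the failure, then retry
    }
    if (attempts < maxRetries) Thread.sleep(retryWaitMs)
  }
  throw new RuntimeException(s"Error sending message after $attempts attempts", lastException)
}

// Usage: an ask that fails twice before succeeding.
var calls = 0
val flakyAsk = () => {
  calls += 1
  if (calls < 3) Future.failed[String](new RuntimeException("transient"))
  else Future.successful("pong")
}
val reply = askWithRetry(flakyAsk, maxRetries = 3, retryWaitMs = 10, timeout = 1.second)
```

Blocking inside the loop is what makes askWithRetry synchronous, in contrast to ask, which returns the Future immediately.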

4. Creating the NettyRpcEnv for the Spark driver env

private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Create the serializer
    val javaSerializerInstance = new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    // Instantiate a NettyRpcEnv
    val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
    if (!config.clientMode) { // not client mode: this env also runs a server
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(actualPort) // start the server
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        // Start the driver RPC service on the configured host and port
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
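Since a NettyRpcEnv can only be created by NettyRpcEnvFactory.create(), the creation flow is:
1) Get the SparkConf
2) Create the serializer
3) Create the NettyRpcEnv
4) If not in client mode, define startNettyRpcEnv, the function that starts the server
5) Call Utils.startServiceOnPort to start the driver RPC service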
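Step 5 depends on Utils.startServiceOnPort retrying when the requested port is already taken (the number of retries comes from spark.port.maxRetries). The sketch below captures just that idea. It is a hypothetical re-implementation, not Spark's code: it simply adds an offset to the start port and skips Spark's extra port-range and error handling.

```scala
import java.net.BindException

// Hypothetical, simplified take on the idea behind Utils.startServiceOnPort:
// try startPort, startPort + 1, ... and stop at the first port that binds.
def startServiceOnPort[T](startPort: Int, startService: Int => (T, Int), maxRetries: Int): (T, Int) = {
  require(maxRetries >= 0, "maxRetries must be non-negative")
  var offset = 0
  var lastError: BindException = null
  while (offset <= maxRetries) {
    // Port 0 asks the OS for any free port, so there is nothing to increment.
    val tryPort = if (startPort == 0) 0 else startPort + offset
    try {
      return startService(tryPort)
    } catch {
      case e: BindException =>
        lastError = e // port taken: remember why and move on to the next one
        offset += 1
    }
  }
  throw new BindException(s"Could not bind after ${maxRetries + 1} attempts: ${lastError.getMessage}")
}

// Usage: ports 7077 and 7078 are "already in use".
val occupied = Set(7077, 7078)
val (service, port) = startServiceOnPort[String](7077, { p =>
  if (occupied.contains(p)) throw new BindException(s"Port $p is in use")
  else (s"rpc-server@$p", p)
}, maxRetries = 16)
```

Returning the actual port alongside the service matters when the start port is 0, because only the started server knows which port the OS picked; that is why startNettyRpcEnv returns nettyEnv.address.port.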
 
Main steps in creating a NettyRpcEnv:
1) private[netty] val transportConf = SparkTransportConf.fromSparkConf(conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0))
// Build the transportConf from the SparkConf
2) private val dispatcher: Dispatcher = new Dispatcher(this)
// Create the Dispatcher, which dispatches messages to endpoints for processing
3) private val streamManager = new NettyStreamManager(this)
// Create the streamManager (a NettyStreamManager), which serves files from this RpcEnv to other nodes
4) private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager))
// Create the TransportContext, which is used to create clients and servers
5) private def createClientBootstraps(): java.util.List[TransportClientBootstrap]
// If authentication is enabled, returns a list holding a SaslClientBootstrap built from transportConf and the SecurityManager; otherwise an empty list
6) private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
// Create the client factory; clients used for RPC communication can only be created through it (file downloads use a separate fileDownloadFactory)
7) val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Create a single daemon thread, netty-rpc-env-timeout, that fails ask() requests whose reply does not arrive within the timeout
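8) private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool("netty-rpc-connection", conf.getInt("spark.rpc.connect.threads", 64))
// Create the executor that opens client connections; spark.rpc.connect.threads (default 64) sets its maximum number of threads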
9) private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
// A concurrent map from each remote RpcAddress to its Outbox; queuing messages in an Outbox makes send non-blocking
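10) def startServer(port: Int): Unit = {
      // ... (omitted) bootstraps: a SaslServerBootstrap when authentication is enabled, otherwise empty
      // Start the underlying transport server through transportContext
      server = transportContext.createServer(host, port, bootstraps)
      // Register an RpcEndpointVerifier, which lets remote RpcEnvs check whether an endpoint with a given name exists here
      dispatcher.registerRpcEndpoint(RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
    }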
11) override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
      dispatcher.registerRpcEndpoint(name, endpoint)
    }
// Register an RpcEndpoint with the RpcEnv under the given name
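12) private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
      if (receiver.client != null) {
        // The NettyRpcEndpointRef already holds a TransportClient: send through it directly
        message.sendWith(receiver.client)
      } else {
        // ... (omitted) targetOutbox: the Outbox for receiver.address, taken from outboxes or created and added there
        targetOutbox.send(message) // otherwise queue the message in the Outbox of the remote RpcAddress
      }
    }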
13) private[netty] def send(message: RequestMessage): Unit = {
      val remoteAddr = message.receiver.address
      if (remoteAddr == address) {
        // Message to a local RPC endpoint: hand it straight to the local RpcEndpoint via the Dispatcher
        dispatcher.postOneWayMessage(message)
      } else {
        // Message to a remote RPC endpoint: serialize it and send it via the Outbox
        postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
      }
    }
14) private[netty] def createClient(address: RpcAddress): TransportClient = {
      clientFactory.createClient(address.host, address.port)
    }
// Create a client for the given host and port
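15) private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
      val promise = Promise[Any]()
      val remoteAddr = message.receiver.address
      // ... (omitted) onSuccess/onFailure: complete promise with the reply or with the error
      try {
        if (remoteAddr == address) {
          // Local address: post the message to the local RpcEndpoint; p is a Promise whose outcome is forwarded to promise
          dispatcher.postLocalMessage(message, p)
        } else {
          // Remote address: rpcMessage wraps the serialized message; send it to the remote RpcEndpoint via the Outbox
          postToOutbox(message.receiver, rpcMessage)
        }
        // ... (omitted) schedule a task on timeoutScheduler that fails promise with a TimeoutException
      } catch {
        case NonFatal(e) => onFailure(e)
      }
      promise.future.mapTo[T] // ... (a recover step that annotates timeouts is omitted)
    }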
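The local/remote split in send (13) and ask (15), together with the timeoutScheduler from step 7, can be modelled without any networking. In this sketch all names (`AskEnv`, `Address`, `Request`, `sendRemote`) are hypothetical: a local request is answered in-process, a remote one is handed to a pluggable `sendRemote` function, and a single daemon scheduler thread fails the reply Promise with a TimeoutException if nothing arrives in time. The timeout task is cancelled once the Promise completes.

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit, TimeoutException}
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.util.Try

// Hypothetical address and request types (not Spark's RpcAddress/RequestMessage).
final case class Address(host: String, port: Int)
final case class Request(receiver: Address, endpointName: String, content: Any)

final class AskEnv(val address: Address,
                   localHandlers: Map[String, Any => Any],
                   sendRemote: (Request, Promise[Any]) => Unit) {
  // Plays the role of "netty-rpc-env-timeout": one daemon thread for timeouts.
  private val timeoutScheduler: ScheduledExecutorService =
    Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
      def newThread(r: Runnable): Thread = {
        val t = new Thread(r, "toy-rpc-env-timeout")
        t.setDaemon(true)
        t
      }
    })

  def ask(message: Request, timeout: FiniteDuration): Future[Any] = {
    val promise = Promise[Any]()
    if (message.receiver == address) {
      // Local endpoint: no serialization and no network hop.
      promise.complete(Try(localHandlers(message.endpointName)(message.content)))
    } else {
      // Remote endpoint: the transport completes the promise when a reply arrives.
      sendRemote(message, promise)
    }
    val timeoutTask = timeoutScheduler.schedule(new Runnable {
      def run(): Unit = {
        promise.tryFailure(new TimeoutException(s"Cannot receive any reply in $timeout"))
        ()
      }
    }, timeout.toNanos, TimeUnit.NANOSECONDS)
    // Once the promise completes either way, the timeout task is no longer needed.
    promise.future.onComplete(_ => timeoutTask.cancel(true))(ExecutionContext.global)
    promise.future
  }
}

// Usage: a local "echo" endpoint, and a remote peer that never answers.
val local = Address("driver", 7077)
val askEnv = new AskEnv(local, Map("echo" -> ((x: Any) => x)), (_, _) => ())
val fast = Await.result(askEnv.ask(Request(local, "echo", "ping"), 1.second), 2.seconds)
val slow = Try(Await.result(askEnv.ask(Request(Address("worker-1", 7078), "echo", "ping"), 100.millis), 2.seconds))
```

Because tryFailure and trySuccess only succeed once, a late reply after the timeout (or a timeout after the reply) is simply ignored.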
 
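Steps 9 and 12 describe the Outbox: messages for a remote address are queued, and the queue is drained once a connection exists, so the caller of send never waits on the network. The sketch below is a simplified single-process model with hypothetical names (`SketchOutbox`, `OutboxTable`, `WireClient`) and strings in place of serialized messages. Its lookup uses the get-then-putIfAbsent pattern on a ConcurrentHashMap, which is also how NettyRpcEnv.postToOutbox finds the target Outbox.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

// Hypothetical stand-in for TransportClient.
trait WireClient {
  def write(message: String): Unit
}

// Hypothetical, simplified Outbox: queue the message, connect lazily, then drain.
final class SketchOutbox(connect: () => WireClient) {
  private val messages = mutable.Queue.empty[String]
  private var client: WireClient = null

  def send(message: String): Unit = synchronized {
    messages.enqueue(message)
    // Spark connects on clientConnectionExecutor and drains from there, so the
    // caller never blocks on the network; this sketch connects inline instead.
    if (client == null) client = connect()
    while (messages.nonEmpty) client.write(messages.dequeue())
  }
}

// One Outbox per remote address, created on first use.
final class OutboxTable(connect: String => WireClient) {
  private val outboxes = new ConcurrentHashMap[String, SketchOutbox]()

  def postToOutbox(address: String, message: String): Unit = {
    val existing = outboxes.get(address)
    val target =
      if (existing != null) existing
      else {
        val created = new SketchOutbox(() => connect(address))
        val raced = outboxes.putIfAbsent(address, created)
        if (raced == null) created else raced // another thread registered one first
      }
    target.send(message)
  }
}

// Usage: record what goes "over the wire" and how many connections are opened.
val wire = mutable.ArrayBuffer.empty[(String, String)]
var connections = 0
val table = new OutboxTable(address => {
  connections += 1
  new WireClient { def write(message: String): Unit = wire += (address -> message) }
})
table.postToOutbox("worker-1:7078", "RegisterWorker")
table.postToOutbox("worker-1:7078", "Heartbeat")
table.postToOutbox("worker-2:7078", "Heartbeat")
```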
Dispatcher
Its main jobs are to hold the registered RpcEndpoints and to dispatch messages to them for processing.
1) private class EndpointData(val name: String, val endpoint: RpcEndpoint, val ref: NettyRpcEndpointRef) {
     val inbox = new Inbox(ref, endpoint)
   }
// Inner class EndpointData bundles an endpoint's name, the RpcEndpoint, its NettyRpcEndpointRef, and an Inbox that queues the endpoint's pending messages
2) // A ConcurrentHashMap from endpoint name to its EndpointData
   private val endpoints = new ConcurrentHashMap[String, EndpointData]
   // A ConcurrentHashMap from each RpcEndpoint to its RpcEndpointRef
   private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
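3) // A blocking queue of EndpointData that have messages waiting. An EndpointData is added to receivers when its endpoint
   // is registered, when a message is posted to it, when it is unregistered, and when the RpcEnv stops; its queued messages are then processed
   private val receivers = new LinkedBlockingQueue[EndpointData]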
4) def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
     // Build the RpcEndpointAddress from the NettyRpcEnv's address and the given name
     val addr = RpcEndpointAddress(nettyEnv.address, name)
     // Create the matching NettyRpcEndpointRef
     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
     synchronized {
       if (stopped) {
         throw new IllegalStateException("RpcEnv has been stopped")
       }
       // Add a new EndpointData under this name; names must be unique
       if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
         throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
       }
       val data = endpoints.get(name)
       // Record the endpoint -> endpointRef mapping in endpointRefs
       endpointRefs.put(data.endpoint, data.ref)
       // Add the new EndpointData to receivers, for the OnStart message
       receivers.offer(data)
     }
     // Return the EndpointRef
     endpointRef
   }
5) // Post a message sent by a remote endpoint
   def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
     val rpcCallContext = new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
     val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
     postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
   }
6) // Post a message sent by a local endpoint
   def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
     val rpcCallContext = new LocalNettyRpcCallContext(message.senderAddress, p)
     val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
     postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
   }
7) def postOneWayMessage(message: RequestMessage): Unit = {
     postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), (e) => throw e)
   }
// Post a one-way message (no reply expected)
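8) private def postMessage(endpointName: String, message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = {
     // ... (omitted) if the RpcEnv is stopped or no endpoint has this name, callbackIfStopped is invoked instead
     val data = endpoints.get(endpointName)
     // Append the message to the messages of this EndpointData's Inbox
     data.inbox.post(message)
     // Add the EndpointData to receivers so that a MessageLoop will process it
     receivers.offer(data)
   }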
9) // Create the thread pool that dispatches messages
   private val threadpool: ThreadPoolExecutor = {
     // Number of threads: spark.rpc.netty.dispatcher.numThreads, defaulting to max(2, available processors)
     val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", math.max(2, Runtime.getRuntime.availableProcessors()))
     // Create the pool
     val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
     // Run one MessageLoop on each thread
     for (i <- 0 until numThreads) {
       pool.execute(new MessageLoop)
     }
     pool
   }
10) private class MessageLoop extends Runnable {
      override def run(): Unit = {
        try {
          while (true) {
            try {
              // Take an EndpointData from receivers; since receivers is a LinkedBlockingQueue, the thread blocks while it is empty
              val data = receivers.take()
              // If it is the PoisonPill, stop this thread and put the PoisonPill back into receivers so every other thread stops too
              if (data == PoisonPill) {
                // Put PoisonPill back so that other MessageLoops can see it.
                receivers.offer(PoisonPill)
                return
              }
              // Call process on the EndpointData's Inbox to handle its queued messages
              data.inbox.process(Dispatcher.this)
            } catch {
              case NonFatal(e) => logError(e.getMessage, e)
            }
          }
        } catch {
          case ie: InterruptedException => // exit
        }
      }
    }
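The thread pool runs MessageLoops that watch receivers for new EndpointData:
If an EndpointData arrives and it is not the PoisonPill, the loop calls process() on its Inbox, which handles messages as follows:
it takes messages one at a time from the head of that EndpointData's Inbox messages queue,
matches each message and calls the corresponding method of the endpoint (for example receive or receiveAndReply).
If receivers is empty, the thread blocks and waits.
If the entry is the PoisonPill, the loop puts it back into receivers and then stops the thread.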
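The interplay of registerRpcEndpoint, postMessage, receivers, and the MessageLoops fits in a small runnable model. Everything here is hypothetical and simplified: the names echo Spark's, but there are no RPC contexts and no OnStop or unregistration. Each process() call handles a single message (Spark's Inbox keeps polling until its queue is empty), and a per-Inbox lock makes each endpoint handle one message at a time in FIFO order, roughly the guarantee Spark gives a ThreadSafeRpcEndpoint.

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, CountDownLatch, Executors, LinkedBlockingQueue, ThreadFactory, TimeUnit}
import scala.collection.mutable
import scala.util.control.NonFatal

// Hypothetical message types, loosely modelled on Spark's InboxMessage.
sealed trait InboxMessage
case object OnStart extends InboxMessage
final case class OneWayMessage(content: Any) extends InboxMessage

trait Endpoint {
  def onStart(): Unit = ()
  def receive: PartialFunction[Any, Unit]
}

final class Inbox(endpoint: Endpoint) {
  // OnStart is queued first, so it is the first message the endpoint handles.
  private val messages = mutable.Queue[InboxMessage](OnStart)

  def post(message: InboxMessage): Unit = synchronized { messages.enqueue(message) }

  // Handle the oldest message while holding the lock: the endpoint then sees
  // one message at a time, in FIFO order, even with several MessageLoops.
  def process(): Unit = synchronized {
    if (messages.nonEmpty) {
      messages.dequeue() match {
        case OnStart            => endpoint.onStart()
        case OneWayMessage(msg) => endpoint.receive.applyOrElse(msg, (_: Any) => ())
      }
    }
  }
}

final class Dispatcher(numThreads: Int) {
  private final class EndpointData(val name: String, val inbox: Inbox)

  private val endpoints = new ConcurrentHashMap[String, EndpointData]()
  private val receivers = new LinkedBlockingQueue[EndpointData]()
  private val PoisonPill = new EndpointData("__poison_pill__", null) // a loop that takes it shuts down

  def registerRpcEndpoint(name: String, endpoint: Endpoint): Unit = {
    val data = new EndpointData(name, new Inbox(endpoint))
    if (endpoints.putIfAbsent(name, data) != null)
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    receivers.offer(data) // schedules the OnStart message
  }

  def postOneWayMessage(name: String, content: Any): Unit = {
    val data = endpoints.get(name)
    if (data == null) throw new IllegalStateException(s"Could not find $name")
    data.inbox.post(OneWayMessage(content))
    receivers.offer(data) // tells some MessageLoop that this endpoint has work
  }

  private def messageLoop(): Unit = {
    try {
      while (true) {
        val data = receivers.take() // blocks while there is nothing to do
        if (data eq PoisonPill) {
          receivers.offer(PoisonPill) // put it back so the other loops stop too
          return
        }
        try data.inbox.process()
        catch { case NonFatal(e) => e.printStackTrace() }
      }
    } catch {
      case _: InterruptedException => () // exit
    }
  }

  private val pool = Executors.newFixedThreadPool(numThreads, new ThreadFactory {
    def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "dispatcher-event-loop")
      t.setDaemon(true)
      t
    }
  })
  (0 until numThreads).foreach(_ => pool.execute(new Runnable { def run(): Unit = messageLoop() }))

  def stop(): Unit = {
    receivers.offer(PoisonPill)
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}

val seen = new ConcurrentLinkedQueue[Any]()
val handled = new CountDownLatch(3)
val dispatcher = new Dispatcher(numThreads = 2)
dispatcher.registerRpcEndpoint("Worker", new Endpoint {
  override def onStart(): Unit = { seen.add("onStart"); handled.countDown() }
  def receive: PartialFunction[Any, Unit] = { case msg => seen.add(msg); handled.countDown() }
})
dispatcher.postOneWayMessage("Worker", "Heartbeat-1")
dispatcher.postOneWayMessage("Worker", "Heartbeat-2")
handled.await(5, TimeUnit.SECONDS)
dispatcher.stop()
```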
 
NettyRpcEndpointRef
private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf, endpointAddress: RpcEndpointAddress, @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) with Serializable with Logging {

  // The TransportClient this ref sends through, if it has one
  @transient @volatile var client: TransportClient = _
  // Keep endpointAddress only when it carries an RpcAddress (endpoints in a client-mode RpcEnv have none)
  private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
  // The endpoint name, taken from endpointAddress
  private val _name = endpointAddress.name

  override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
  // Custom deserialization: re-attach the transient nettyEnv and client from the NettyRpcEnv doing the deserializing
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    nettyEnv = NettyRpcEnv.currentEnv.value
    client = NettyRpcEnv.currentClient.value
  }
  // Custom serialization: only the non-transient fields are written
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
  }

  override def name: String = _name
  // Override RpcEndpointRef.ask: wrap the message in a RequestMessage and delegate to NettyRpcEnv.ask
  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
  }
  // Override RpcEndpointRef.send: wrap the message in a RequestMessage and delegate to NettyRpcEnv.send
  override def send(message: Any): Unit = {
    require(message != null, "Message is null")
    nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
  }
}
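readObject works because NettyRpcEnv.currentEnv and currentClient are scala.util.DynamicVariable values that the env binds with withValue around deserialization, so a ref decoded on, say, an executor picks up that executor's NettyRpcEnv. The sketch below isolates the pattern without real Java serialization: `ToyEnv`, `ToyRef`, `EnvContext`, and `deserializeWithin` are hypothetical names, and the explicit rebind() call stands in for the readObject hook.

```scala
import scala.util.DynamicVariable

// Hypothetical stand-in for NettyRpcEnv.
final class ToyEnv(val name: String)

object EnvContext {
  // Like NettyRpcEnv.currentEnv: a scoped, per-thread "current environment".
  val currentEnv = new DynamicVariable[ToyEnv](null)
}

// Hypothetical ref whose env field would not survive serialization.
final class ToyRef(val endpointName: String) {
  @transient var env: ToyEnv = null

  // Does what NettyRpcEndpointRef.readObject does after defaultReadObject():
  // adopt whichever env is currently deserializing.
  def rebind(): Unit = { env = EnvContext.currentEnv.value }
}

// Mirrors the env side: run the decode step with currentEnv bound to `env`.
def deserializeWithin(env: ToyEnv)(decode: () => ToyRef): ToyRef =
  EnvContext.currentEnv.withValue(env) {
    val ref = decode() // stands in for Java deserialization of the bytes
    ref.rebind()       // stands in for the readObject hook
    ref
  }

// Usage: the same logical ref, "received" by an executor-side env.
val executorEnv = new ToyEnv("executor-env")
val decodedRef = deserializeWithin(executorEnv)(() => new ToyRef("Master"))
```

Outside the withValue block the variable reverts to its previous value, so the binding cannot leak into unrelated code on the same thread.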
 
 