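Spark's RPC layer is built on the Netty communication framework.
Core classes: RpcEnv, RpcEndpoint, RpcEndpointRef
For example, the Spark Master and each Worker are RpcEndpoints (communication endpoints). The object used to talk to an RpcEndpoint, an RpcEndpointRef, can only be obtained through an RpcEnv. By default Spark uses the more efficient NettyRpcEnv.
When a client sends a message through an RpcEndpointRef, the RpcEnv handles it first. The RpcEnv works out who the message is for and routes it to that RpcEndpoint.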
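1. RpcEnv: the RPC environment. It manages the whole lifecycle of RpcEndpoints, from registering them, through message handling, to stopping them. An RpcEnv can only be created through an RpcEnvFactory. Its core method:
def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef // registers an RpcEndpoint with the RpcEnv under `name` and manages the binding between the RpcEndpoint and its RpcEndpointRef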
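The division of labour between the three classes can be illustrated with a minimal in-process sketch. It assumes no networking and no serialization. ToyRpcEnv, ToyEndpoint, ToyEndpointRef and CounterEndpoint are hypothetical names, not Spark classes. The sketch only mirrors the contract described above. setupEndpoint registers an endpoint under a unique name and returns a ref. The ref never calls the endpoint directly; it hands each message to the environment, which routes it by name.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical, in-process stand-ins: no networking, no serialization.
trait ToyEndpoint {
  def receive: PartialFunction[Any, Unit]
}

// The caller only holds a ref; the ref hands every message back to the env for routing.
final class ToyEndpointRef(val name: String, env: ToyRpcEnv) {
  def send(message: Any): Unit = env.route(name, message)
}

final class ToyRpcEnv {
  private val endpoints = new ConcurrentHashMap[String, ToyEndpoint]()

  // Same contract as setupEndpoint: register under a unique name and return a ref.
  def setupEndpoint(name: String, endpoint: ToyEndpoint): ToyEndpointRef = {
    if (endpoints.putIfAbsent(name, endpoint) != null) {
      throw new IllegalArgumentException(s"There is already an endpoint called $name")
    }
    new ToyEndpointRef(name, this)
  }

  // Look the endpoint up by name and let its receive handle the message.
  def route(name: String, message: Any): Unit = {
    val endpoint = endpoints.get(name)
    if (endpoint == null) throw new IllegalStateException(s"Could not find $name")
    val handler = endpoint.receive
    if (handler.isDefinedAt(message)) handler(message)
    else throw new IllegalStateException(s"$name cannot handle $message")
  }
}

// Usage: an endpoint that adds up the integers it receives.
final class CounterEndpoint extends ToyEndpoint {
  @volatile var total = 0
  def receive: PartialFunction[Any, Unit] = {
    case n: Int => total += n
  }
}

val env = new ToyRpcEnv
val counter = new CounterEndpoint
val ref = env.setupEndpoint("counter", counter)
ref.send(3)
ref.send(4)
```

Keeping the name-to-endpoint map inside the environment is what lets a ref stay a lightweight handle. The ref only needs a name (and, in Spark, an address), never a pointer to the endpoint object itself.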
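2. RpcEndpoint: an object that takes part in RPC communication. It defines the endpoint lifecycle (constructor -> onStart -> receive -> onStop) and event-driven callbacks for connection events (connected, disconnected, network error). The core job of an RpcEndpoint is to receive and process messages: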
def receive: PartialFunction[Any, Unit] = {
  case _ => throw new SparkException(self + " does not implement 'receive'")
}
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}
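RpcEndpoint.receive handles messages sent with RpcEndpointRef.send(). The endpoint processes them and sends no reply.
RpcEndpoint.receiveAndReply() handles messages sent with RpcEndpointRef.ask(). After processing, the endpoint must send a response back to the caller of ask through the RpcCallContext.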
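The difference between the two entry points can be sketched as follows. ToyCallContext stands in for RpcCallContext. deliverSend and deliverAsk are hypothetical helpers playing the roles of send and ask. In Spark this delivery happens inside the Inbox, which the sketch does not reproduce. The defaults mirror the excerpt above: an endpoint that does not override receiveAndReply answers every ask with a failure.

```scala
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._
import scala.util.Try

// Hypothetical stand-in for RpcCallContext: the endpoint answers an ask through it.
final class ToyCallContext(promise: Promise[Any]) {
  def reply(response: Any): Unit = { promise.trySuccess(response); () }
  def sendFailure(e: Throwable): Unit = { promise.tryFailure(e); () }
}

// Defaults modelled on the RpcEndpoint excerpt: unhandled messages fail.
trait ToyEndpoint {
  def receive: PartialFunction[Any, Unit] = {
    case _ => throw new IllegalStateException(s"$this does not implement 'receive'")
  }
  def receiveAndReply(context: ToyCallContext): PartialFunction[Any, Unit] = {
    case _ => context.sendFailure(new IllegalStateException(s"$this won't reply anything"))
  }
}

final class EchoEndpoint extends ToyEndpoint {
  @volatile var lastLogged = ""
  // One-way messages (send): handled, nothing goes back to the caller.
  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => lastLogged = msg
  }
  // Request/response messages (ask): the endpoint must answer via the context.
  override def receiveAndReply(context: ToyCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg.toUpperCase)
  }
}

// Hypothetical delivery helpers playing the roles of RpcEndpointRef.send and ask.
def deliverSend(endpoint: ToyEndpoint, message: Any): Unit = endpoint.receive(message)

def deliverAsk(endpoint: ToyEndpoint, message: Any): Future[Any] = {
  val promise = Promise[Any]()
  val context = new ToyCallContext(promise)
  try {
    endpoint.receiveAndReply(context).applyOrElse[Any, Unit](message, m =>
      throw new IllegalArgumentException(s"Unsupported message $m"))
  } catch {
    case e: Exception => context.sendFailure(e) // a failing handler becomes a failed reply
  }
  promise.future
}

val echo = new EchoEndpoint
deliverSend(echo, "hello")
val answer = Await.result(deliverAsk(echo, "hi"), 1.second)
val unsupported = Try(Await.result(deliverAsk(echo, 42), 1.second))
```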
3. RpcEndpointRef
An RpcEndpointRef is a reference to an RpcEndpoint that may live in another process, much like a remote proxy.
private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging {
  private[this] val maxRetries = RpcUtils.numRetries(conf)
  private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
  private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
  def address: RpcAddress
  def name: String
  def send(message: Any): Unit // one-way message, no reply
  def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] // the reply arrives through the Future
  def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
    ... ... // sends the message with ask, waits on the returned Future for the reply, and retries on failure
  }
}
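The retry behaviour behind askWithRetry can be sketched as a loop around a blocking wait. This is a simplification under stated assumptions, not Spark's code. `ask` is a hypothetical function that produces one attempt's Future. maxRetries and retryWaitMs are plain parameters here rather than values read from SparkConf through RpcUtils.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal

// Sketch of ask-then-await-with-retries. `ask` stands for one RPC attempt (hypothetical);
// maxRetries and retryWaitMs mirror the RpcEndpointRef fields of the same names.
def askWithRetry[T](ask: () => Future[T], maxRetries: Int, retryWaitMs: Long, timeout: FiniteDuration): T = {
  var attempts = 0
  var lastException: Throwable = null
  while (attempts < maxRetries) {
    attempts += 1
    try {
      // Block the caller until this attempt's reply arrives or `timeout` expires.
      return Await.result(ask(), timeout)
    } catch {
      case NonFatal(e) =>
        lastException = e
        if (attempts < maxRetries) Thread.sleep(retryWaitMs) // pause before the next attempt
    }
  }
  throw new IllegalStateException(s"Error sending message after $attempts attempts", lastException)
}

// Usage: an "endpoint" that fails twice before answering.
var calls = 0
val flaky: () => Future[String] = () => {
  calls += 1
  if (calls < 3) Future.failed(new RuntimeException(s"attempt $calls failed"))
  else Future.successful("pong")
}
val reply = askWithRetry(flaky, maxRetries = 3, retryWaitMs = 10L, timeout = 1.second)
```

Every attempt blocks the calling thread for up to `timeout`. In the worst case a caller therefore waits roughly maxRetries times the timeout, plus the pauses in between.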
4. How the Spark driver environment creates a NettyRpcEnv
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // create the serializer
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    // create the NettyRpcEnv instance
    val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
    if (!config.clientMode) { // only when not in client mode
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(actualPort) // start the server on actualPort
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        // start the RPC service on the configured host and port
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
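Since a NettyRpcEnv can only be created through NettyRpcEnvFactory.create(), the creation flow is:
1) Get the SparkConf
2) Create the serializer (a JavaSerializerInstance)
3) Create the NettyRpcEnv
4) If not in client mode, define startNettyRpcEnv, a function that starts the server on a given port
5) Start the driver RPC service with Utils.startServiceOnPort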
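Step 5 delegates port handling to Utils.startServiceOnPort. As the call above shows, it receives a starting port and a function that starts the service on a given port. That function returns the service together with the port it actually bound. The following is a simplified sketch of the retry idea. The retry count is a plain parameter here, whereas Spark reads it from the spark.port.maxRetries setting, and Spark's wrap-around at the top of the port range is omitted. fakeStart is a hypothetical stand-in for startNettyRpcEnv that pretends ports 7077 and 7078 are taken.

```scala
import java.net.BindException

// Simplified sketch of "try the next port on BindException".
def startServiceOnPort[T](startPort: Int, startService: Int => (T, Int), maxRetries: Int): (T, Int) = {
  var offset = 0
  while (true) {
    // Port 0 asks the OS for any free port, so it is never incremented.
    val tryPort = if (startPort == 0) 0 else startPort + offset
    try {
      return startService(tryPort) // (service, port it actually bound to)
    } catch {
      // Port taken and retries left: try the next port; otherwise the exception propagates.
      case _: BindException if offset < maxRetries => offset += 1
    }
  }
  throw new IllegalStateException("unreachable")
}

// Usage with a simulated service whose first two ports are busy.
val busyPorts = Set(7077, 7078)
val fakeStart: Int => (String, Int) = port =>
  if (busyPorts.contains(port)) throw new BindException(s"Address already in use: $port")
  else (s"server@$port", port)

val (server, boundPort) = startServiceOnPort(7077, fakeStart, maxRetries = 16)
```

Passing a start function instead of a fixed port is what allows the factory to reuse one retry helper for any service.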
The main steps when a NettyRpcEnv is constructed:
1) private[netty] val transportConf = SparkTransportConf.fromSparkConf(conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0))
// builds transportConf from the SparkConf
2) private val dispatcher: Dispatcher = new Dispatcher(this)
// creates the Dispatcher, which routes incoming messages to endpoints
3) private val streamManager = new NettyStreamManager(this)
// creates the NettyStreamManager, which serves files and jars registered with this RpcEnv
4) private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager))
// creates the TransportContext, used to create clients and servers
5) private def createClientBootstraps(): java.util.List[TransportClientBootstrap]
// if authentication is enabled in the SecurityManager, returns a SaslClientBootstrap built from the TransportConf and SecurityManager; otherwise an empty list
6) private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
// creates the client factory; the TransportClients used for communication can only be created through it (file downloads use a separate fileDownloadFactory)
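7) val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// creates a single daemon thread named netty-rpc-env-timeout that schedules ask timeouts, failing a request when no reply arrives in time
8) private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool("netty-rpc-connection", conf.getInt("spark.rpc.connect.threads", 64))
// creates the thread pool used to establish client connections; spark.rpc.connect.threads caps its size (default 64)
9) private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
// a concurrent map from each remote RpcAddress to its Outbox; queueing messages in an Outbox makes send non-blocking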
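10) def startServer(port: Int): Unit = {
  // bootstraps: a SaslServerBootstrap when authentication is enabled, otherwise empty (definition omitted)
  // start the underlying transport server through transportContext
  server = transportContext.createServer(host, port, bootstraps)
  // register the RpcEndpointVerifier, which remote RpcEnvs query to check whether an endpoint with a given name exists here
  dispatcher.registerRpcEndpoint(RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}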
11) override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
  dispatcher.registerRpcEndpoint(name, endpoint)
}
// registers an RpcEndpoint with the RpcEnv under the given name
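12) private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
  if (receiver.client != null) {
    // the NettyRpcEndpointRef already holds a TransportClient: send through it directly
    message.sendWith(receiver.client)
  } else {
    // otherwise find the Outbox for the remote RpcAddress (targetOutbox is looked up in outboxes,
    // or created and added with putIfAbsent; that lookup is omitted here) and queue the message there
    targetOutbox.send(message)
  }
}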
13) private[netty] def send(message: RequestMessage): Unit = {
  val remoteAddr = message.receiver.address
  if (remoteAddr == address) {
    // Message to a local RPC endpoint: hand it straight to the local Dispatcher.
    dispatcher.postOneWayMessage(message)
  } else {
    // Message to a remote RPC endpoint: serialize it and route it through postToOutbox.
    postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
  }
}
14) private[netty] def createClient(address: RpcAddress): TransportClient = {
  clientFactory.createClient(address.host, address.port)
}
// creates a TransportClient connected to the given host and port
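15) private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
  val promise = Promise[Any]()
  val remoteAddr = message.receiver.address
  def onFailure(e: Throwable): Unit = promise.tryFailure(e)
  def onSuccess(reply: Any): Unit = reply match {
    case RpcFailure(e) => onFailure(e) // the endpoint answered with a failure
    case rpcReply => promise.trySuccess(rpcReply)
  }
  try {
    if (remoteAddr == address) {
      // local address: hand the message to the local RpcEndpoint; p is completed with its reply
      val p = Promise[Any]()
      p.future.onComplete(_.fold(onFailure, onSuccess))(ThreadUtils.sameThread)
      dispatcher.postLocalMessage(message, p)
    } else {
      val rpcMessage = RpcOutboxMessage(serialize(message), onFailure,
        (client, response) => onSuccess(deserialize[Any](client, response)))
      postToOutbox(message.receiver, rpcMessage) // remote address: the message's callbacks complete the promise
    }
  } catch { case NonFatal(e) => onFailure(e) }
  promise.future.mapTo[T] // (scheduling the timeout on timeoutScheduler is omitted here)
}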
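Steps 9, 12 and 13 together decide where a message goes. The sketch below illustrates that routing under simplifying assumptions. Address, ToyOutbox and ToyRouter are hypothetical names, messages are plain strings, and drainTo merely stands in for the connection work that Spark's Outbox performs before writing to a TransportClient. It shows the two decisions made above. Local messages bypass the network. Remote ones are queued in a per-address Outbox, obtained with a putIfAbsent get-or-create similar to the one in Spark's postToOutbox.

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}

// Hypothetical stand-in for RpcAddress.
final case class Address(host: String, port: Int)

// Toy Outbox: send only enqueues, so the caller never waits for a connection.
final class ToyOutbox(val address: Address) {
  private val messages = new ConcurrentLinkedQueue[String]()
  def send(message: String): Unit = { messages.add(message); () }
  // Stands in for the later step that writes queued messages to a client, in FIFO order.
  def drainTo(sink: String => Unit): Unit = {
    var m = messages.poll()
    while (m != null) { sink(m); m = messages.poll() }
  }
}

final class ToyRouter(localAddress: Address) {
  private val outboxes = new ConcurrentHashMap[Address, ToyOutbox]()
  private val localInbox = new ConcurrentLinkedQueue[String]()

  // Get-or-create with putIfAbsent: racing threads still end up sharing one Outbox per address.
  private def outboxFor(address: Address): ToyOutbox = {
    val existing = outboxes.get(address)
    if (existing != null) existing
    else {
      val created = new ToyOutbox(address)
      val previous = outboxes.putIfAbsent(address, created)
      if (previous == null) created else previous
    }
  }

  // Local messages skip the network; remote ones go to the target address's Outbox.
  def send(target: Address, message: String): Unit =
    if (target == localAddress) { localInbox.add(message); () }
    else outboxFor(target).send(message)

  def localCount: Int = localInbox.size
  def outbox(address: Address): Option[ToyOutbox] = Option(outboxes.get(address))
}

val here = Address("driver", 7077)
val worker = Address("worker-1", 35000)
val router = new ToyRouter(here)
router.send(here, "local-1")
router.send(worker, "remote-1")
router.send(worker, "remote-2")

val written = scala.collection.mutable.ArrayBuffer.empty[String]
router.outbox(worker).foreach(_.drainTo(m => { written += m; () }))
```

Because send only appends to a queue, a slow or not-yet-connected remote address cannot stall the caller. The cost of connecting is paid later, on a different thread.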
Dispatcher
The Dispatcher keeps track of registered RpcEndpoints and dispatches each message to the right RpcEndpoint for processing.
1) private class EndpointData(val name: String, val endpoint: RpcEndpoint, val ref: NettyRpcEndpointRef) {
  val inbox = new Inbox(ref, endpoint)
}
// the inner class EndpointData groups an endpoint's name, the RpcEndpoint, its NettyRpcEndpointRef, and an Inbox that queues the messages for that endpoint
2) // a ConcurrentHashMap from each endpoint name to its EndpointData
private val endpoints = new ConcurrentHashMap[String, EndpointData]
// a ConcurrentHashMap from each RpcEndpoint to its RpcEndpointRef
private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
3) // a LinkedBlockingQueue of EndpointData whose inboxes have work to process. An EndpointData is added
// when its endpoint is registered, when a message is posted to it, when it is unregistered, and when the RpcEnv stops
private val receivers = new LinkedBlockingQueue[EndpointData]
4) def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  // build an RpcEndpointAddress from the NettyRpcEnv's address and the endpoint name
  val addr = RpcEndpointAddress(nettyEnv.address, name)
  // create the matching NettyRpcEndpointRef
  val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  synchronized {
    // add a new EndpointData under this name; names must be unique
    if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    val data = endpoints.get(name)
    // record the endpoint -> endpointRef mapping in endpointRefs
    endpointRefs.put(data.endpoint, data.ref)
    // queue the new EndpointData in receivers
    receivers.offer(data) // for the OnStart message
  }
  // return the EndpointRef
  endpointRef
}
5) // posts a message sent by a remote endpoint
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
  val rpcCallContext = new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
  val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
6) // posts a message sent by a local endpoint
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
  val rpcCallContext = new LocalNettyRpcCallContext(message.senderAddress, p)
  val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
}
7) def postOneWayMessage(message: RequestMessage): Unit = {
  postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), (e) => throw e)
}
// posts a one-way message (no reply is expected)
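8) private def postMessage(endpointName: String, message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = {
  val data = endpoints.get(endpointName)
  if (data == null) {
    // no endpoint with this name: report the error through the callback
    callbackIfStopped(new SparkException(s"Could not find $endpointName."))
  } else {
    // append the message to this EndpointData's inbox
    data.inbox.post(message)
    // queue the EndpointData in receivers so a MessageLoop thread processes it
    receivers.offer(data)
  }
}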
9) // create the thread pool that dispatches messages
private val threadpool: ThreadPoolExecutor = {
  // number of threads: spark.rpc.netty.dispatcher.numThreads, default max(2, available processors)
  val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", math.max(2, Runtime.getRuntime.availableProcessors()))
  // create the pool
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  // run one MessageLoop on each thread
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}
10) private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          // take an EndpointData from receivers; because receivers is a LinkedBlockingQueue,
          // the thread blocks while it is empty
          val data = receivers.take()
          // the PoisonPill stops this thread
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            receivers.offer(PoisonPill)
            return
          }
          // let this EndpointData's Inbox process its pending messages
          data.inbox.process(Dispatcher.this)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
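So the Dispatcher runs a pool of threads that watch receivers for new EndpointData.
If an EndpointData is available and it is not the PoisonPill, the thread calls process on its Inbox, which handles the messages as follows:
it takes messages from the Inbox's messages queue one at a time, starting with the first,
and matches each message and calls the corresponding method of the endpoint (for example receive or receiveAndReply).
If receivers is empty, the thread blocks until an EndpointData arrives.
If the element taken is the PoisonPill, the thread puts it back into receivers, so that the other threads see it too, and then exits.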
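The register/post/MessageLoop cycle described above can be condensed into a small sketch. It is a simplification under stated assumptions, not Spark's Dispatcher. ToyDispatcher, ToyEndpointData and Recorder are hypothetical. Messages are strings and handlers are plain functions. The per-endpoint lock in process() only stands in for the concurrency control in Spark's Inbox. The sketch keeps the parts discussed here:
- one inbox per endpoint;
- a LinkedBlockingQueue of endpoints with pending work;
- a fixed pool of MessageLoop threads blocking on take();
- a PoisonPill that each loop puts back before exiting.

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Executors, LinkedBlockingQueue, TimeUnit}
import scala.util.control.NonFatal

// Per-endpoint state, like Dispatcher.EndpointData: a name, a handler and an inbox.
final class ToyEndpointData(val name: String, val handler: String => Unit) {
  val inbox = new ConcurrentLinkedQueue[String]()
  // Drain the inbox in FIFO order; the lock lets one thread at a time work on this endpoint.
  def process(): Unit = synchronized {
    var msg = inbox.poll()
    while (msg != null) {
      handler(msg)
      msg = inbox.poll()
    }
  }
}

final class ToyDispatcher(numThreads: Int) {
  private val endpoints = new ConcurrentHashMap[String, ToyEndpointData]()
  // Endpoints that have pending messages, like Dispatcher.receivers.
  private val receivers = new LinkedBlockingQueue[ToyEndpointData]()
  // Sentinel that tells MessageLoop threads to exit, like Dispatcher's PoisonPill.
  private val PoisonPill = new ToyEndpointData("poison-pill", _ => ())
  private val pool = Executors.newFixedThreadPool(numThreads)

  private class MessageLoop extends Runnable {
    override def run(): Unit =
      try {
        while (true) {
          val data = receivers.take() // blocks while no endpoint has work
          if (data eq PoisonPill) {
            receivers.offer(PoisonPill) // put it back so the other loops stop too
            return
          }
          try data.process()
          catch { case NonFatal(e) => System.err.println(s"error in ${data.name}: $e") }
        }
      } catch {
        case _: InterruptedException => // exit
      }
  }

  (0 until numThreads).foreach(_ => pool.execute(new MessageLoop))

  def registerEndpoint(name: String, handler: String => Unit): Unit = {
    if (endpoints.putIfAbsent(name, new ToyEndpointData(name, handler)) != null) {
      throw new IllegalArgumentException(s"There is already an endpoint called $name")
    }
  }

  def postMessage(endpointName: String, message: String): Unit = {
    val data = endpoints.get(endpointName)
    if (data == null) throw new IllegalArgumentException(s"Could not find $endpointName")
    data.inbox.add(message) // 1) append to the endpoint's inbox
    receivers.offer(data)   // 2) signal that this endpoint has work
  }

  // Enqueue the PoisonPill behind any pending work, then wait for the loops to finish.
  def stop(): Unit = {
    receivers.offer(PoisonPill)
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}

// A named handler class that records what the endpoint saw.
final class Recorder extends (String => Unit) {
  val seen = new ConcurrentLinkedQueue[String]()
  def apply(message: String): Unit = seen.add(s"logger:$message")
}

val recorder = new Recorder
val dispatcher = new ToyDispatcher(numThreads = 2)
dispatcher.registerEndpoint("logger", recorder)
dispatcher.postMessage("logger", "a")
dispatcher.postMessage("logger", "b")
dispatcher.stop()
```

stop() enqueues the PoisonPill behind any EndpointData already in receivers. Work queued before stop() is therefore drained before the threads exit. A message posted after stop() would never be processed, which is one reason Spark's Dispatcher also tracks whether it has been stopped.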
NettyRpcEndpointRef
private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv)
  extends RpcEndpointRef(conf) with Serializable with Logging {
  // the TransportClient used to reach the endpoint (transient: not serialized)
  @transient @volatile var client: TransportClient = _
  // keep endpointAddress only if it carries an RpcAddress (endpoints of client-mode RpcEnvs have none)
  private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
  // the endpoint's name, taken from endpointAddress
  private val _name = endpointAddress.name
  override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
  // custom deserialization: re-attach the transient nettyEnv and client from the
  // NettyRpcEnv.currentEnv / currentClient dynamic variables of the current thread
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    nettyEnv = NettyRpcEnv.currentEnv.value
    client = NettyRpcEnv.currentClient.value
  }
  // custom serialization: only the non-transient fields are written
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
  }
  override def name: String = _name
  // overrides RpcEndpointRef.ask: wraps the message in a RequestMessage and delegates to NettyRpcEnv.ask
  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
  }
  // overrides RpcEndpointRef.send: wraps the message in a RequestMessage and delegates to NettyRpcEnv.send
  override def send(message: Any): Unit = {
    require(message != null, "Message is null")
    nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
  }
}
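The readObject method above relies on NettyRpcEnv.currentEnv, a scala.util.DynamicVariable. NettyRpcEnv sets it with withValue for the duration of a deserialization. The sketch below isolates that mechanism. To stay self-contained, it replaces Java serialization with a plain constructor call plus a reattach step. ToyEnv, ToyRef, currentEnv and deserializeWith are hypothetical names.

```scala
import scala.util.DynamicVariable

// Hypothetical stand-in for NettyRpcEnv; it is never serialized itself.
final class ToyEnv(val id: String)

// Plays the role of NettyRpcEnv.currentEnv: the env performing the current deserialization.
val currentEnv = new DynamicVariable[ToyEnv](null)

// Stand-in for a freshly deserialized NettyRpcEndpointRef: only `name` "came over the wire";
// the transient env is re-attached from currentEnv, as readObject does.
final class ToyRef(val name: String) {
  @volatile var env: ToyEnv = null
  def reattach(): ToyRef = { env = currentEnv.value; this }
}

// Stand-in for the env's deserialize step: currentEnv is bound only while decoding runs.
def deserializeWith(env: ToyEnv, name: String): ToyRef =
  currentEnv.withValue(env) {
    new ToyRef(name).reattach()
  }

val driverEnv = new ToyEnv("driver")
val ref = deserializeWith(driverEnv, "Master")
```

DynamicVariable is bound per thread (child threads inherit the value). Concurrent deserializations on different threads therefore each see their own env, and the binding is restored once withValue returns.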