Golang 原生Rpc Server实现


引言

本文我们来看看golang原生rpc库的实现 , 首先来看一下golang rpc库的demo案例:

  • 服务端和客户端公共代码
type HelloService interface {
	Hello(request *Request, response *Response) error
}

type Request struct {
	Header map[string]interface{}
	Params map[string]interface{}
}

type Response struct {
	Header map[string]interface{}
	Params map[string]interface{}
}
  • 服务端代码
type HelloServiceImpl int

func NewServer() {
	helloImpl := new(HelloServiceImpl)
	rpc.RegisterName("helloService", helloImpl)
	rpc.HandleHTTP()
	if err := http.ListenAndServe(":1235", nil); err != nil {
		log.Fatal("server error: ", err)
	}
}

func (s *HelloServiceImpl) Hello(request *common.Request, response *common.Response) error {
	response.Header = request.Header
	response.Params = map[string]interface{}{
		"data": "Hello World",
	}
	return nil
}
  • 客户端代码
func NewClient() *common.Response {
	client, err := rpc.DialHTTP("tcp", ":1234")
	if err != nil {
		log.Fatal("dialing: ", err)
	}

	res := &common.Response{}

	err = client.Call("helloService.Hello", &common.Request{
		map[string]interface{}{
			"client": "val1",
		}, map[string]interface{}{
			"data": "hello world",
		},
	}, res)

	if err != nil {
		log.Fatal("call: ", err)
	}
	return res
}

golang 原生 rpc 库的使用基本还是分为两步走:

server 端 :

  1. 服务注册
  2. 启动服务

server端对注册的方法有一定的限制,方法必须满足签名:

func (t *T) MethodName(argType T1, replyType *T2) error
  • 首先,方法必须是导出的(名字首字母大写);
  • 其次,方法接受两个参数,必须是导出的或内置类型。第一个参数表示客户端传递过来的请求参数,第二个是需要返回给客户端的响应。第二个参数必须为指针类型(需要修改);
  • 最后,方法必须返回一个error类型的值。返回非nil的值,表示调用出错。

rpc.HandleHTTP()注册 HTTP 路由。http.ListenAndServe(“:1234”, nil)在端口1234上启动一个 HTTP 服务,请求 rpc 方法会交给rpc内部路由处理。这样我们就可以通过客户端调用这两个方法了。


client 端 :

  1. 连接服务端
  2. 调用接口

客户端比服务端稍微简单一点,我们使用rpc.DialHTTP(“tcp”, “:1234”)连接到服务端的监听地址,返回一个 rpc 的客户端对象。后续就可以调用该对象的Call()方法调用服务端对象的对应方法,依次传入方法名(需要加上类型限定)、参数、一个指针(用于接收返回值)


源码解析

对net/http包不熟悉的童鞋可能会觉得奇怪,rpc.HandleHTTP()与http.ListenAndServer(“:1234”, nil)是怎么联系起来的?我们简单看一下源码:

// src/net/rpc/server.go
const (
  // Defaults used by HandleHTTP
  DefaultRPCPath   = "/_goRPC_"
  DefaultDebugPath = "/debug/rpc"
)

func (server *Server) HandleHTTP(rpcPath, debugPath string) {
  http.Handle(rpcPath, server)
  http.Handle(debugPath, debugHTTP{server})
}

func HandleHTTP() {
  DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}

实际上,rpc.HandleHTTP()会调用http.Handle()在预定义的路径上(/_goRPC_)注册处理器。这个处理器最终被添加到net/http包中的默认多路复用器上:

// src/net/http/server.go
func Handle(pattern string, handler Handler) {
  DefaultServeMux.Handle(pattern, handler)
}

而http.ListenAndServer()第二个参数传入nil时也是使用默认的多路复用器。

有关golang http server 实现,可阅读:

细心的朋友可能发现了,除了默认的路径/_goRPC_用来处理 RPC 请求,rpc.HandleHTTP()方法还注册了一个调试路径/debug/rpc。我们可以直接在浏览器中访问这个网址(需要服务端程序开启。如果服务端在远程,需要相应地修改地址)localhost:1234,直观的查看各个方法的调用情况:

在这里插入图片描述


当我们访问/_goRPC_路径 , 最终调用到的请求处理器是net/rpc/server包下的ServerHttp函数:

func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 must CONNECT\n")
		return
	}
	// 拦截http连接拦截,获取原生的connection
	conn, _, err := w.(http.Hijacker).Hijack()
	...
	io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
	// 连接上后续的数据读写都走rpc 协议 , 不走 http 协议了
	server.ServeConn(conn)
}

服务端

数据结构

首先来看一下承载Rpc服务核心状态的Server结构体实现:

// Server represents an RPC Server.
type Server struct {
	serviceMap sync.Map   // map[string]*service. 服务信息映射集合
	reqLock    sync.Mutex // protects freeReq.  
	freeReq    *Request 
	respLock   sync.Mutex // protects freeResp
	freeResp   *Response
}

其次是包含了注册服务信息的service结构体实现:

type service struct {
	name   string                 // 服务名
	rcvr   reflect.Value          // 服务实现类
	typ    reflect.Type           // 服务实现类类型
	method map[string]*methodType // 当前服务接口注册进来的方法列表
}

type methodType struct {
	sync.Mutex 
	method     reflect.Method
	ArgType    reflect.Type 
	ReplyType  reflect.Type
	numCalls   uint
}

下面是golang rpc通信使用到的请求和响应对象结构 , 请求和响应对象都会采用对象池进行复用,所以都有next属性:

// Request is a header written before every RPC call. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Request struct {
	ServiceMethod string   // format: "Service.Method"
	Seq           uint64   // sequence number chosen by client
	next          *Request // for free list in Server
}

// Response is a header written before every RPC return. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Response struct {
	ServiceMethod string    // echoes that of the Request
	Seq           uint64    // echoes that of the request
	Error         string    // error, if any.
	next          *Response // for free list in Server
}

服务注册

通过调用RegisterName函数,我们可以向rpc server的服务映射集合中保存当前服务信息:

// 服务名 , 服务实现类
func RegisterName(name string, rcvr any) error {
	return DefaultServer.RegisterName(name, rcvr)
}

func (server *Server) RegisterName(name string, rcvr any) error {
	return server.register(rcvr, name, true)
}

func (server *Server) register(rcvr any, name string, useName bool) error {
    // 创建一个新的服务信息类
	s := new(service)
	// 反射获取当前服务实现类的类型和值
	s.typ = reflect.TypeOf(rcvr)
	s.rcvr = reflect.ValueOf(rcvr)
	// 保存服务名
	sname := name
	// useName 表示是否使用传入的name作为服务名 , 如果为false , 则采用服务实现类的类型名
	if !useName {
		sname = reflect.Indirect(s.rcvr).Type().Name()
	}
	if sname == "" {
		s := "rpc.Register: no service name for type " + s.typ.String()
		log.Print(s)
		return errors.New(s)
	}
	// 如果采用服务实现类的类型名作为服务名,要确保服务实现类是导出的,对外可见
	if !useName && !token.IsExported(sname) {
		s := "rpc.Register: type " + sname + " is not exported"
		log.Print(s)
		return errors.New(s)
	}
	s.name = sname

	// 构建注册服务方法列表信息
	s.method = suitableMethods(s.typ, logRegisterError)

	if len(s.method) == 0 {
		str := ""

		// To help the user, see if a pointer receiver would work.
		method := suitableMethods(reflect.PointerTo(s.typ), false)
		if len(method) != 0 {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
		} else {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
		}
		log.Print(str)
		return errors.New(str)
	}
    // 判断服务名是否重复
	if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
		return errors.New("rpc: service already defined: " + sname)
	}
	return nil
}

suitableMethods方法用于遍历当前服务实现类所有导出方法,并筛选出符合RPC调用格式的方法列表:

func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
	methods := make(map[string]*methodType)
	// 遍历当前服务实现类的所有方法
	for m := 0; m < typ.NumMethod(); m++ {
	    // 定位方法元数据对象
		method := typ.Method(m)
		// 获取方法类型和方法名
		mtype := method.Type
		mname := method.Name
		// 跳过未导出的方法
		if !method.IsExported() {
			continue
		}
		// Method needs three ins: receiver, *args, *reply.
		// 方法参数必须有两个,第一个用于作为请求参数,第二个用于接收请求结果
		if mtype.NumIn() != 3 {
			if logErr {
				log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
			}
			continue
		}
		// First arg need not be a pointer.
		// 第一个参数可以不是指针类型
		argType := mtype.In(1)
		if !isExportedOrBuiltinType(argType) {
			if logErr {
				log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
			}
			continue
		}
		// Second arg must be a pointer.
		// 第二个参数必须是指针类型
		replyType := mtype.In(2)
		if replyType.Kind() != reflect.Pointer {
			if logErr {
				log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
			}
			continue
		}
		// Reply type must be exported.
		// 第二个参数类型必须是导出的
		if !isExportedOrBuiltinType(replyType) {
			if logErr {
				log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
			}
			continue
		}
		// 方法必须只有一个返回值,同时返回值类型必须是error类型
		// Method needs one out.
		if mtype.NumOut() != 1 {
			if logErr {
				log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
			}
			continue
		}
		// The return type of the method must be error.
		if returnType := mtype.Out(0); returnType != typeOfError {
			if logErr {
				log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
			}
			continue
		}
		// 构造方法类型信息: 方法元数据本身,方法第一个入参类型,方法第二个入参类型
		methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
	}
	return methods
}

请求处理

本文一开始给出的Demo是借助 Http Server 来 Accept 用户连接,当接收到用户连接后,会通过Hijack获取到原生连接,然后后续该连接上的客户端读写事件都采用gob编码进行通信,而非http协议了:

func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	buf := bufio.NewWriter(conn)
	// 构建gob编码器
	srv := &gobServerCodec{
		rwc:    conn,
		dec:    gob.NewDecoder(conn),
		enc:    gob.NewEncoder(buf),
		encBuf: buf,
	}
	// 使用gob编码器从连接到读取字节流,然后按照golang RPC协议执行反序列化
	server.ServeCodec(srv)
}

ServeCodec 函数会按照gob编码反序列化得到RPC请求头和请求数据,然后调用目标,最终将结果按gob编码执行序列化,写会connection中:

func (server *Server) ServeCodec(codec ServerCodec) {
	sending := new(sync.Mutex)
	wg := new(sync.WaitGroup)
	for {
	    // 解析得到请求数据
		service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
		if err != nil {
			if debugLog && err != io.EOF {
				log.Println("rpc:", err)
			}
			// 读取完所有请求后,退出循环
			if !keepReading {
				break
			}
			// send a response if we actually managed to read a header.
			if req != nil {
				server.sendResponse(sending, req, invalidRequest, codec, err.Error())
				server.freeRequest(req)
			}
			continue
		}
		wg.Add(1)
		// 处理请求调用
		go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
	}
	// We've seen that there are no more requests.
	// Wait for responses to be sent before closing codec.
	// 等待所有响应被处理完毕
	wg.Wait()
	codec.Close()
}

golang rpc 调用,发出的请求数据由两部分组成,首先是请求头,其次是RPC函数入参数的第一个对象,同样也是按照这个顺序依次执行反序列化读取:

func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
    // 解析请求头
	service, mtype, req, keepReading, err = server.readRequestHeader(codec)
	if err != nil {
		if !keepReading {
			return
		}
		// discard body
		codec.ReadRequestBody(nil)
		return
	}
     
	// Decode the argument value.
	argIsValue := false 
	// 如果rpc方法的第一个参数(请求参数)类型是指针,则解引用拿到原始类型
	// 然后以原始类型分配一块新的内存,返回指向该内存的指针
	if mtype.ArgType.Kind() == reflect.Pointer {
		argv = reflect.New(mtype.ArgType.Elem())
	} else {
		argv = reflect.New(mtype.ArgType)
		argIsValue = true
	}
	// 反序列化得到请求参数的具体值,设置到argv指向到的零值结构体中
	if err = codec.ReadRequestBody(argv.Interface()); err != nil {
		return
	}
	// 如果目标RPC方法的请求入参是值类型,则进行解引用
	if argIsValue {
		argv = argv.Elem()
	}
    
    // 为第二个参数(返回值参数)同样初始化零值
	replyv = reflect.New(mtype.ReplyType.Elem())
   // 如果返回值参数类型为Map或者Slice,则初始化空map或切片
	switch mtype.ReplyType.Elem().Kind() {
	case reflect.Map:
		replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
	case reflect.Slice:
		replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
	}
	return
}

golang rpc 请求头由调用方法信息和请求序列号组成 , 反序列化后,可以拿到服务名和方法名,根据方法名去server的服务映射集合中定位具体的方法元数据对象:

func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
	// 从请求对象池中获取一个空闲的请求对象
	req = server.getRequest()
	// 采用gob编码器将请求头部分字节流反序列化为req对象类型
	err = codec.ReadRequestHeader(req)
	if err != nil {
		req = nil
		// 字节流读完了
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			return
		}
		err = errors.New("rpc: server cannot decode request: " + err.Error())
		return
	}

	// We read the header successfully. If we see an error now,
	// we can still recover and move on to the next request.
	keepReading = true
    // 分割得到服务名和客户端想要调用的方法名 
	dot := strings.LastIndex(req.ServiceMethod, ".")
	if dot < 0 {
		err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
		return
	}
	serviceName := req.ServiceMethod[:dot]
	methodName := req.ServiceMethod[dot+1:]

	// Look up the request.
	// 根据服务名加载对应的服务信息类
	svci, ok := server.serviceMap.Load(serviceName)
	if !ok {
		err = errors.New("rpc: can't find service " + req.ServiceMethod)
		return
	}
	// 拿到服务信息类后,根据方法名定位获取到对应的方法类型
	svc = svci.(*service)
	mtype = svc.method[methodName]
	if mtype == nil {
		err = errors.New("rpc: can't find method " + req.ServiceMethod)
	}
	return
}

反序列化拿到请求数据后,便可以查询服务映射集合拿到对应的方法信息,最后我们便可以借助反射完成方法调用了:

func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
	if wg != nil {
		defer wg.Done()
	}
	mtype.Lock()
	// 当前方法调用次数加一
	mtype.numCalls++
	mtype.Unlock()
	// 拿到方法句柄
	function := mtype.method.Func
	// 传入方法实际调用者,即服务实现类,方法的第一个和第二个请求参数
	returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
	// 方法执行完毕后,拿到方法返回值 -- 代表error
	errInter := returnValues[0].Interface()
	errmsg := ""
	if errInter != nil {
		errmsg = errInter.(error).Error()
	}
	// 发送响应给客户端
	server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
	// 释放当前请求对象到对象池中
	server.freeRequest(req)
}

本地方法执行完毕后,需要组装响应对象,然后将响应对象执行gob编码,然后发送到连接中:

var invalidRequest = struct{}{}

func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
    // 从响应池中获取到空闲的响应对象
	resp := server.getResponse()
	// Encode the response header
	// 组装响应对象
	resp.ServiceMethod = req.ServiceMethod
	if errmsg != "" {
		resp.Error = errmsg
		reply = invalidRequest
	}
	resp.Seq = req.Seq
	// 将响应对象执行gob编码,然后发送到conn中
	sending.Lock()
	err := codec.WriteResponse(resp, reply)
	if debugLog && err != nil {
		log.Println("rpc: writing response:", err)
	}
	sending.Unlock()
	// 将响应对象返回到对象池中
	server.freeResponse(resp)
}

客户端

数据结构

首先是代表客户端对象的Client结构:

// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	codec ClientCodec  // 请求数据编解码器,默认是gob协议

	reqMutex sync.Mutex // protects following
	request  Request  // 此处请求对象结构复用了/rpc/server包下的请求对象结构

	mutex    sync.Mutex // protects following
	seq      uint64.  // 请求序列号
	pending  map[uint64]*Call // 已经发出但还未回复的rpc调用
	closing  bool // user has called Close
	shutdown bool // server has told us to stop
}

Call 结构体承载了RPC远程调用的上下文信息

// Call represents an active RPC.
type Call struct {
	ServiceMethod string     // The name of the service and method to call.
	Args          any        // The argument to the function (*struct).
	Reply         any        // The reply from the function (*struct).
	Error         error      // After completion, the error status.
	Done          chan *Call // Receives *Call when Go is complete.
}

建立连接

当服务端采用HTTP协议来接收客户端连接时,客户端就必须通过调用DialHttp来与服务端建立连接:

func DialHTTP(network, address string) (*Client, error) {
    // 使用默认的RPC建立连接的请求路径: /_goRPC_ 
	return DialHTTPPath(network, address, DefaultRPCPath)
}

func DialHTTPPath(network, address, path string) (*Client, error) {
    // 建立TCP连接
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	// 发出connect请求
	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")

	// Require successful HTTP response
	// before switching to RPC protocol.
	// 再转换为采用RPC协议通信时,需要确保此处的响应是成功的
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	if err == nil && resp.Status == connected {
		return NewClient(conn), nil
	}
	if err == nil {
		err = errors.New("unexpected HTTP response: " + resp.Status)
	}
	conn.Close()
	return nil, &net.OpError{
		Op:   "dial-http",
		Net:  network + " " + address,
		Addr: nil,
		Err:  err,
	}
}

当成功连接服务端时,会创建一个新的客户端对象并返回:

func NewClient(conn io.ReadWriteCloser) *Client {
	encBuf := bufio.NewWriter(conn)
	// client端默认采用gob编码
	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
	return NewClientWithCodec(client)
}

但是在一个新的客户端初始化时,会启动一个永不停歇的协程来不断接收并处理来自服务端的响应数据:

func NewClientWithCodec(codec ClientCodec) *Client {
	client := &Client{
		codec:   codec,
		pending: make(map[uint64]*Call),
	}
	// 启动一个永不停歇的协程来不断接收并处理来自服务端的响应数据
	go client.input()
	return client
}

input 协程采用死循环来不断读取服务端响应,并进行处理:

func (client *Client) input() {
	var err error
	var response Response // rpc/server包下的Response对象
	// 死循环来不断接收服务端响应,直到解析请求体的过程中出现错误,才会退出循环
	for err == nil {
		response = Response{} 
		// 读取响应头
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		// 拿到响应序列号,得知该响应是对客户端发出的哪个请求的响应
		seq := response.Seq
		client.mutex.Lock()
		// 从pending集合中定位对应的call对象
		call := client.pending[seq]
		// 从集合中移除该对象
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		// 如果pending集合中不存在call对象,说明可能是重复响应,说明存在错误
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		// 响应头中错误信息不为空	
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			// 通知本次请求结束
			call.done()
	    // 正常响应 		
		default:
		    // 读取响应结果
			err = client.codec.ReadResponseBody(call.Reply)
			// 存在错误则记录
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			// 通知本次请求处理结束
			call.done()
		}
	}
	// 如果解析请求体的过程中出现错误,则退出上面的循环 
	// Terminate pending calls.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	// 终止所有已发送还未接收到响应的请求
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}

请求调用

rpc client端通过调用Call方法来完成远程过程调用:

func (client *Client) Call(serviceMethod string, args any, reply any) error {
	// 同步阻塞直到请求响应接收到为止,Done信号在input协程中被设置,或者请求发送过程中出现错误时被设置
	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
	return call.Error
}

func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
	// 构建请求调用对象
	call := new(Call)
	call.ServiceMethod = serviceMethod
	call.Args = args
	call.Reply = reply
	if done == nil {
		done = make(chan *Call, 10) // buffered.
	} else {
		// If caller passes done != nil, it must arrange that
		// done has enough buffer for the number of simultaneous
		// RPCs that will be using that channel. If the channel
		// is totally unbuffered, it's best not to run at all.
		if cap(done) == 0 {
			log.Panic("rpc: done channel is unbuffered")
		}
	}
	call.Done = done
	// 发送请求
	client.send(call)
	return call
}

实际请求发送会调用client的send方法完成:

func (client *Client) send(call *Call) {
	client.reqMutex.Lock()
	defer client.reqMutex.Unlock()

	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		client.mutex.Unlock()
		call.Error = ErrShutdown
		call.done()
		return
	}
	// 为当前请求设置请求序列号,同时将当前请求调用添加进pending集合
	seq := client.seq
	client.seq++
	client.pending[seq] = call
	client.mutex.Unlock()

	// Encode and send the request.
	// 构建请求对象
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
	// 发送请求 --- 此处发送完毕请求后,就直接返回了,不会等待响应结果
	err := client.codec.WriteRequest(&client.request, call.Args)
	if err != nil {
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}

延伸

异步调用

上文中举的例子,客户端实际是同步调用模式,首先WriteRequest发送请求方法是异步的,但是Call方法会等待直到Done信号有值时,才会返回。

改造为异步模式也很简单,直接调用Go方法,并在合适的时机调用监听Done通道是否有值即可:

func NewClient() *common.Response {
	client, err := rpc.DialHTTP("tcp", ":1234")
	if err != nil {
		log.Fatal("dialing: ", err)
	}

	res := &common.Response{}

	call := client.Go("helloService.Hello", &common.Request{
		map[string]interface{}{
			"client": "val1",
		}, map[string]interface{}{
			"data": "hello world",
		},
	}, res, nil)

	ticker := time.NewTicker(time.Millisecond)
	defer ticker.Stop()

	select {
	case replyCall := <-call.Done:
		if err := replyCall.Error; err != nil {
			fmt.Println("rpc error:", err)
		} else {
			fmt.Printf("res= %v", replyCall)
		}
	case t := <-ticker.C:
		fmt.Println("Current time: ", t)
	}
	
	return res
}

定制服务名

默认情况下,rpc.Register()将方法接收者(receiver)的类型名作为服务名。我们也可以自己设置。这时需要调用RegisterName(name string, rcvr interface{}) error方法,我们一开始给出的例子就是采用了后者,忘记的可以回看源码。


采用TPC协议建立连接

上面我们都是使用 HTTP 协议来实现 rpc 服务的,rpc库也支持直接使用 TCP 协议。首先,服务端先调用net.Listen("tcp", ":1234")创建一个监听某个 TCP 端口的监听器(Accepter),然后使用rpc.Accept(l)在此监听器上接受连接并处理:

type HelloServiceImpl int

func NewServer() {
	helloImpl := new(HelloServiceImpl)
	l, err := net.Listen("tcp", ":1236")
	if err != nil {
		return
	}

	rpc.Register(helloImpl)
	rpc.Accept(l)
}

func (s *HelloServiceImpl) Hello(request *common.Request, response *common.Response) error {
	response.Header = request.Header
	response.Params = map[string]interface{}{
		"data": "Hello World",
	}
	return nil
}

此处就相当于建立连接的时候就不采用http的connect请求方式了,只要TCP连接建立成功,就认为RPC连接建立成功:

func Accept(lis net.Listener) { DefaultServer.Accept(lis) }

func (server *Server) Accept(lis net.Listener) {
	for {
		conn, err := lis.Accept()
		if err != nil {
			log.Print("rpc.Serve: accept:", err.Error())
			return
		}
		go server.ServeConn(conn)
	}
}

然后,客户端调用rpc.Dial()以 TCP 协议连接到服务端:

func NewClient() *common.Response {
	client, err := rpc.Dial("tcp", ":1236")
	if err != nil {
		log.Fatal("dialing: ", err)
	}

	res := &common.Response{}

	call := client.Go("helloService.Hello", &common.Request{
		map[string]interface{}{
			"client": "val1",
		}, map[string]interface{}{
			"data": "hello world",
		},
	}, res, nil)

	ticker := time.NewTicker(time.Millisecond)
	defer ticker.Stop()

	select {
	case replyCall := <-call.Done:
		if err := replyCall.Error; err != nil {
			fmt.Println("rpc error:", err)
		} else {
			fmt.Printf("res= %v", replyCall)
		}
	case t := <-ticker.C:
		fmt.Println("Current time: ", t)
	}

	return res
}

相比于基于Http协议建立连接的方式,此处就直接建立TCP连接就完事了,而无需再发送Connect请求:

// Dial connects to an RPC server at the specified network address.
func Dial(network, address string) (*Client, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	return NewClient(conn), nil
}

自定义编码格式

默认客户端与服务端之间的数据使用gob编码,我们可以使用其它的格式来编码。在服务端,我们要实现rpc.ServerCodec接口:

// src/net/rpc/server.go
type ServerCodec interface {
  ReadRequestHeader(*Request) error
  ReadRequestBody(interface{}) error
  WriteResponse(*Response, interface{}) error

  Close() error
}

实际上不用这么麻烦,我们查看源码看看gobServerCodec是怎么实现的,然后仿造实现一个就行了。下面我实现了一个 JSON 格式的编解码器:

type JsonServerCodec struct {
  rwc    io.ReadWriteCloser
  dec    *json.Decoder
  enc    *json.Encoder
  encBuf *bufio.Writer
  closed bool
}

func NewJsonServerCodec(conn io.ReadWriteCloser) *JsonServerCodec {
  buf := bufio.NewWriter(conn)
  return &JsonServerCodec{conn, json.NewDecoder(conn), json.NewEncoder(buf), buf, false}
}

func (c *JsonServerCodec) ReadRequestHeader(r *rpc.Request) error {
  return c.dec.Decode(r)
}

func (c *JsonServerCodec) ReadRequestBody(body interface{}) error {
  return c.dec.Decode(body)
}

func (c *JsonServerCodec) WriteResponse(r *rpc.Response, body interface{}) (err error) {
  if err = c.enc.Encode(r); err != nil {
    if c.encBuf.Flush() == nil {
      log.Println("rpc: json error encoding response:", err)
      c.Close()
    }
    return
  }
  if err = c.enc.Encode(body); err != nil {
    if c.encBuf.Flush() == nil {
      log.Println("rpc: json error encoding body:", err)
      c.Close()
    }
    return
  }
  return c.encBuf.Flush()
}

func (c *JsonServerCodec) Close() error {
  if c.closed {
    return nil
  }
  c.closed = true
  return c.rwc.Close()
}

server端的for循环中需要创建编解码器JsonServerCodec传给ServeCodec方法:

func NewServer() {
	helloImpl := new(HelloServiceImpl)
	l, err := net.Listen("tcp", ":1236")
	if err != nil {
		return
	}

	rpc.Register(helloImpl)

	for {
		conn, err := l.Accept()
		if err != nil {
			return
		}
		go rpc.ServeCodec(common.NewJsonServerCodec(conn))
	}
}

同样的,客户端要实现rpc.ClientCodec接口,也是仿造gobClientCodec的实现:

type JsonClientCodec struct {
  rwc    io.ReadWriteCloser
  dec    *json.Decoder
  enc    *json.Encoder
  encBuf *bufio.Writer
}

func NewJsonClientCodec(conn io.ReadWriteCloser) *JsonClientCodec {
  encBuf := bufio.NewWriter(conn)
  return &JsonClientCodec{conn, json.NewDecoder(conn), json.NewEncoder(encBuf), encBuf}
}

func (c *JsonClientCodec) WriteRequest(r *rpc.Request, body interface{}) (err error) {
  if err = c.enc.Encode(r); err != nil {
    return
  }
  if err = c.enc.Encode(body); err != nil {
    return
  }
  return c.encBuf.Flush()
}

func (c *JsonClientCodec) ReadResponseHeader(r *rpc.Response) error {
  return c.dec.Decode(r)
}

func (c *JsonClientCodec) ReadResponseBody(body interface{}) error {
  return c.dec.Decode(body)
}

func (c *JsonClientCodec) Close() error {
  return c.rwc.Close()
}

要使用NewClientWithCodec以指定的编解码器创建客户端:

func NewClient() *common.Response {
	conn, err := net.Dial("tcp", ":1234")
	if err != nil {
		return nil
	}

	client := rpc.NewClientWithCodec(common.NewJsonClientCodec(conn))
	res := &common.Response{}

	err = client.Call("helloService.Hello", &common.Request{
		map[string]interface{}{
			"client": "val1",
		}, map[string]interface{}{
			"data": "hello world",
		},
	}, res)

	return res
}

自定义服务器

实际上,上面我们调用的方法rpc.Register,rpc.RegisterName,rpc.ServeConn,rpc.ServeCodec都是转而去调用默认DefaultServer的相关方法:

// src/net/rpc/server.go
var DefaultServer = NewServer()

func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }

func RegisterName(name string, rcvr interface{}) error {
  return DefaultServer.RegisterName(name, rcvr)
}

func ServeConn(conn io.ReadWriteCloser) {
  DefaultServer.ServeConn(conn)
}

func ServeCodec(codec ServerCodec) {
  DefaultServer.ServeCodec(codec)
}

但是因为DefaultServer是全局共享的,如果有第三方库使用了相关方法,并且注册了一些对象的方法,我们引用这个第三方库之后,就出现两个问题。第一,可能与我们注册的方法冲突;第二,带来额外的安全隐患(库中方法直接panic?)。故而推荐做法是自己NewServer:

func main() {
  arith := new(Arith)
  server := rpc.NewServer()
  server.RegisterName("math", arith)
  server.HandleHTTP(rpc.DefaultRPCPath, rpc.DefaultDebugPath)

  if err := http.ListenAndServe(":1234", nil); err != nil {
    log.Fatal("serve error:", err)
  }
}

这其实是一个套路,很多库会提供一个默认的实现直接使用,如log、net/http这些库。但是也提供了创建和自定义的方法。一般测试时为了方便可以使用默认实现,实践中最好自己创建相应的对象,避免干扰和安全问题。


参考

延伸部分主要摘录至: Go 每日一库之 rpc

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值