gRPC Source Code Analysis


My previous post on gRPC only gave an overall introduction to the framework and its features; some implementation details and conceptual points still left me with questions. The material I found elsewhere was largely the same surface-level overview, so here I try to answer those questions by reading the gRPC source code.

Goals

  • Understand the overall architecture of the grpc project
  • Understand the grpc connection pool and how the client/server sides work
  • Understand the design patterns behind features such as grpc interceptors

Server-side processing flow

	lis, err := net.Listen("tcp", port)
	if err != nil {
		logs.Error("failed to listen, err is %s", err)
		return // don't fall through to Serve with a nil listener
	}
	s := grpc.NewServer()
	pb.RegisterUserServiceServer(s, &server{})
	if err := s.Serve(lis); err != nil {
		logs.Error("failed to serve, err is %s", err)
	}

1. Create a TCP listener

lis, err := net.Listen("tcp", port)

2. Create the server

s := grpc.NewServer()

This call creates the Server struct and assigns its fields, such as the interceptors; the interceptor flow is covered in detail later.

The Server struct is as follows:

type Server struct {
	opts serverOptions

	mu     sync.Mutex // guards following
	lis    map[net.Listener]bool
	conns  map[transport.ServerTransport]bool
	serve  bool
	drain  bool
	cv     *sync.Cond          // signaled when connections close for GracefulStop
	m      map[string]*service // service name -> service info
	events trace.EventLog

	quit               *grpcsync.Event
	done               *grpcsync.Event
	channelzRemoveOnce sync.Once
	serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop

	channelzID int64 // channelz unique identification number
	czData     *channelzData
}

The most important fields are the maps lis, conns, and m: lis holds the TCP listeners, conns holds the active server transports, and m maps each service name to its service info.

grpc.NewServer takes variadic parameters of type opt ...ServerOption, such as interceptors.
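For example, options can be passed at construction time like this (the specific options are illustrative; myUnaryInterceptor is a placeholder name for any function with the grpc.UnaryServerInterceptor signature, and a concrete one is sketched in step 4):

s := grpc.NewServer(
	grpc.UnaryInterceptor(myUnaryInterceptor), // register a server-side unary interceptor
	grpc.ConnectionTimeout(5*time.Second),     // handshake deadline used later in handleRawConn
)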

3. Register the service

func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	ht := reflect.TypeOf(sd.HandlerType).Elem()
	st := reflect.TypeOf(ss)
	if !st.Implements(ht) {
		grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
	}
	s.register(sd, ss)
}

func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, ok := s.m[sd.ServiceName]; ok {
		grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}
	srv := &service{
		server: ss,
		md:     make(map[string]*MethodDesc),
		sd:     make(map[string]*StreamDesc),
		mdata:  sd.Metadata,
	}
	for i := range sd.Methods {
		d := &sd.Methods[i]
		srv.md[d.MethodName] = d
	}
	for i := range sd.Streams {
		d := &sd.Streams[i]
		srv.sd[d.StreamName] = d
	}
	s.m[sd.ServiceName] = srv
}

The code is straightforward: it injects the service's information into the Server struct, keyed by the service name, and wraps each of the service's method names into a map, so the corresponding handler can be looked up by method name.

The ServiceDesc struct is as follows:

type ServiceDesc struct {
	ServiceName string
	// The pointer to the service interface. Used to check whether the user
	// provided implementation satisfies the interface requirements.
	HandlerType interface{}
	Methods     []MethodDesc
	Streams     []StreamDesc
	Metadata    interface{}
}

In a real project, the generated ServiceDesc with its MethodDesc entries looks like this:

var _UserService_serviceDesc = grpc.ServiceDesc{
	ServiceName: "service.UserService",
	HandlerType: (*UserServiceServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "login",
			Handler:    _UserService_Login_Handler,
		},
		{
			MethodName: "getUserInfo",
			Handler:    _UserService_GetUserInfo_Handler,
		},
		{
			MethodName: "updateUserInfo",
			Handler:    _UserService_UpdateUserInfo_Handler,
		},
		{
			MethodName: "uploadProfilePic",
			Handler:    _UserService_UploadProfilePic_Handler,
		},
	},
	Streams:  []grpc.StreamDesc{},
	Metadata: "userService.proto",
}

The Streams field in this code can be ignored for now; it relates to gRPC's streaming mode.

4. Actually starting the gRPC server

The steps above only register the service and open the TCP listener; the gRPC server has not actually started yet. The entry point that starts it is Serve.

The code is as follows:

func (s *Server) Serve(lis net.Listener) error {
	// ... non-essential code elided ...
	var tempDelay time.Duration // how long to sleep on accept failure

	for {
		rawConn, err := lis.Accept()
		if err != nil {
			if ne, ok := err.(interface {
				Temporary() bool
			}); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				s.mu.Lock()
				s.printf("Accept error: %v; retrying in %v", err, tempDelay)
				s.mu.Unlock()
				timer := time.NewTimer(tempDelay)
				select {
				case <-timer.C:
				case <-s.quit.Done():
					timer.Stop()
					return nil
				}
				continue
			}
			s.mu.Lock()
			s.printf("done serving; Accept = %v", err)
			s.mu.Unlock()

			if s.quit.HasFired() {
				return nil
			}
			return err
		}
		tempDelay = 0
		// Start a new goroutine to deal with rawConn so we don't stall this Accept
		// loop goroutine.
		//
		// Make sure we account for the goroutine so GracefulStop doesn't nil out
		// s.conns before this conn can be added.
		s.serveWG.Add(1)
		go func() {
			s.handleRawConn(rawConn)
			s.serveWG.Done()
		}()
	}
}

The overall logic is a for loop that accepts connections on the registered TCP listener. When Accept returns a temporary error, Serve sleeps before retrying (starting at 5ms and doubling on each failure, capped at 1s). Each accepted connection is handed off to its own goroutine so the Accept loop never stalls, and serveWG counts those goroutines; that count is what GracefulStop relies on.
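As an aside, that is the hook graceful shutdown uses. A minimal sketch of wiring it up (the signal handling here is an assumption, not from the original code):

go func() {
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGTERM) // requires os/signal and syscall
	<-sig
	// Stop accepting new connections, then wait for the Serve goroutines
	// counted by serveWG and for in-flight RPCs to drain.
	s.GracefulStop()
}()

The core per-connection logic is the handleRawConn method: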

func (s *Server) handleRawConn(rawConn net.Conn) {
	if s.quit.HasFired() {
		rawConn.Close()
		return
	}
	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
	conn, authInfo, err := s.useTransportAuthenticator(rawConn)
	if err != nil {
		// ErrConnDispatched means that the connection was dispatched away from
		// gRPC; those connections should be left open.
		if err != credentials.ErrConnDispatched {
			s.mu.Lock()
			s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
			s.mu.Unlock()
			channelz.Warningf(s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
			rawConn.Close()
		}
		rawConn.SetDeadline(time.Time{})
		return
	}

	// Finish handshaking (HTTP2)
	st := s.newHTTP2Transport(conn, authInfo)
	if st == nil {
		return
	}

	rawConn.SetDeadline(time.Time{})
	if !s.addConn(st) {
		return
	}
	go func() {
		s.serveStreams(st)
		s.removeConn(st)
	}()
}

In handleRawConn, the server first runs the transport authenticator (auth is a gRPC feature, though many projects never configure it) and then opens an HTTP/2 transport st.

newHTTP2Transport establishes the HTTP/2 transport; a Wireshark capture of the handshake is shown below.

[Figure: Wireshark capture of the HTTP/2 connection setup]
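The authenticator is configured through a ServerOption; a minimal sketch with TLS (the certificate file names are assumptions):

creds, err := credentials.NewServerTLSFromFile("server.crt", "server.key") // google.golang.org/grpc/credentials
if err != nil {
	logs.Error("failed to load TLS credentials, err is %s", err)
}
s := grpc.NewServer(grpc.Creds(creds))

With this set, useTransportAuthenticator performs the TLS handshake on each raw connection before the HTTP/2 transport is created.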

Following how st is used, the key call is s.serveStreams(st) in the newly spawned goroutine; for every incoming stream it ends up in handleStream, where the logic becomes quite clear:

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
			trInfo.tr.SetError()
		}
		errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
		if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
			if trInfo != nil {
				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				trInfo.tr.SetError()
			}
			channelz.Warningf(s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
		}
		if trInfo != nil {
			trInfo.tr.Finish()
		}
		return
	}
	service := sm[:pos]
	method := sm[pos+1:]

	srv, knownService := s.m[service]
	if knownService {
		if md, ok := srv.md[method]; ok {
			s.processUnaryRPC(t, stream, srv, md, trInfo)
			return
		}
		if sd, ok := srv.sd[method]; ok {
			s.processStreamingRPC(t, stream, srv, sd, trInfo)
			return
		}
	}
	// Unknown service, or known server unknown method.
	if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
		s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
		return
	}
	var errDesc string
	if !knownService {
		errDesc = fmt.Sprintf("unknown service %v", service)
	} else {
		errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
	}
	if trInfo != nil {
		trInfo.tr.LazyPrintf("%s", errDesc)
		trInfo.tr.SetError()
	}
	if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
			trInfo.tr.SetError()
		}
		channelz.Warningf(s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
	}
	if trInfo != nil {
		trInfo.tr.Finish()
	}
}

A gRPC request path has the form /service/method, so the code first locates the last '/' and splits out the service and method names. It then looks both up in the Server struct maps described earlier and dispatches to s.processUnaryRPC(t, stream, srv, md, trInfo) or s.processStreamingRPC(t, stream, srv, sd, trInfo).
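A quick worked example of the split (the method string here is hypothetical):

sm := "service.UserService/login" // leading '/' already stripped
pos := strings.LastIndex(sm, "/")
service := sm[:pos]  // "service.UserService"
method := sm[pos+1:] // "login"

We will only follow the simpler unary RPC path; inside processUnaryRPC the handler is finally invoked: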

ctx := NewContextWithServerTransportStream(stream.Context(), stream)
reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)

and the response is written back:

if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {

md.Handler is the handler from the MethodDesc registered when the service was set up; its last argument, s.opts.unaryInt, is the server-side interceptor, which was stored on the Server struct back in step 2 when NewServer was called:

func NewServer(opt ...ServerOption) *Server {
	opts := defaultServerOptions
	for _, o := range opt {
		o.apply(&opts)
	}
	s := &Server{
		lis:    make(map[net.Listener]bool),
		opts:   opts,
		conns:  make(map[transport.ServerTransport]bool),
		m:      make(map[string]*service),
		quit:   grpcsync.NewEvent(),
		done:   grpcsync.NewEvent(),
		czData: new(channelzData),
	}
	chainUnaryServerInterceptors(s)
	chainStreamServerInterceptors(s)
	s.cv = sync.NewCond(&s.mu)
	if EnableTracing {
		_, file, line, _ := runtime.Caller(1)
		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
	}

	if channelz.IsOn() {
		s.channelzID = channelz.RegisterServer(&channelzServer{s}, "")
	}
	return s
}

Each method name maps to one handler function; the generated handler logic looks like this:

func _UserService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(LoginRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(UserServiceServer).Login(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/service.UserService/Login",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(UserServiceServer).Login(ctx, req.(*LoginRequest))
	}
	return interceptor(ctx, in, info, handler)
}

As you can see, if no interceptor is configured, the service's implementation is called directly; otherwise the interceptor is invoked with the real handler passed in, and it handles the call and the return.
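For reference, a minimal sketch of a server-side unary interceptor and how it is wired in (the timing logic is illustrative, not from the original post):

func timingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	start := time.Now()
	resp, err := handler(ctx, req) // call through to the real handler
	log.Printf("method=%s took=%s err=%v", info.FullMethod, time.Since(start), err)
	return resp, err
}

s := grpc.NewServer(grpc.UnaryInterceptor(timingInterceptor))

This becomes the s.opts.unaryInt that the generated handler receives above.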

sendResponse first encodes the message (proto by default, though the codec is customizable), then compresses it, and finally writes it back to the client:

func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
	data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
	if err != nil {
		channelz.Error(s.channelzID, "grpc: server failed to encode response: ", err)
		return err
	}
	compData, err := compress(data, cp, comp)
	if err != nil {
		channelz.Error(s.channelzID, "grpc: server failed to compress response: ", err)
		return err
	}
	hdr, payload := msgHeader(data, compData)
	// TODO(dfawley): should we be checking len(data) instead?
	if len(payload) > s.opts.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
	}
	err = t.Write(stream, hdr, payload, opts)
	if err == nil && s.opts.statsHandler != nil {
		s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
	}
	return err
}
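On "the codec is customizable": the codec that s.getCodec returns can be swapped by registering a custom encoding.Codec; proto is only the default. A hedged sketch using JSON (illustrative only):

type jsonCodec struct{} // implements google.golang.org/grpc/encoding.Codec

func (jsonCodec) Marshal(v interface{}) ([]byte, error)      { return json.Marshal(v) }
func (jsonCodec) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) }
func (jsonCodec) Name() string                               { return "json" }

func init() { encoding.RegisterCodec(jsonCodec{}) }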

Summary

To sum up the server side: it listens on a port, accepts incoming requests, and routes each one by service and method name to the corresponding handler.

Client-side processing flow

The client issues a gRPC request as follows.

1. Establish the connection

conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock())

Dial returns a *ClientConn:

func Dial(target string, opts ...DialOption) (*ClientConn, error) {
	return DialContext(context.Background(), target, opts...)
}

The core connection logic is in DialContext:

func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
	cc := &ClientConn{
		target:            target,
		csMgr:             &connectivityStateManager{},
		conns:             make(map[*addrConn]struct{}),
		dopts:             defaultDialOptions(),
		blockingpicker:    newPickerWrapper(),
		czData:            new(channelzData),
		firstResolveEvent: grpcsync.NewEvent(),
	}
	cc.retryThrottler.Store((*retryThrottler)(nil))
	cc.ctx, cc.cancel = context.WithCancel(context.Background())

	for _, opt := range opts {
		opt.apply(&cc.dopts)
	}

	chainUnaryClientInterceptors(cc)
	chainStreamClientInterceptors(cc)

	defer func() {
		if err != nil {
			cc.Close()
		}
	}()

	if channelz.IsOn() {
		if cc.dopts.channelzParentID != 0 {
			cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target)
			channelz.AddTraceEvent(cc.channelzID, 0, &channelz.TraceEventDesc{
				Desc:     "Channel Created",
				Severity: channelz.CtINFO,
				Parent: &channelz.TraceEventDesc{
					Desc:     fmt.Sprintf("Nested Channel(id:%d) created", cc.channelzID),
					Severity: channelz.CtINFO,
				},
			})
		} else {
			cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, 0, target)
			channelz.Info(cc.channelzID, "Channel Created")
		}
		cc.csMgr.channelzID = cc.channelzID
	}

	if !cc.dopts.insecure {
		if cc.dopts.copts.TransportCredentials == nil && cc.dopts.copts.CredsBundle == nil {
			return nil, errNoTransportSecurity
		}
		if cc.dopts.copts.TransportCredentials != nil && cc.dopts.copts.CredsBundle != nil {
			return nil, errTransportCredsAndBundle
		}
	} else {
		if cc.dopts.copts.TransportCredentials != nil || cc.dopts.copts.CredsBundle != nil {
			return nil, errCredentialsConflict
		}
		for _, cd := range cc.dopts.copts.PerRPCCredentials {
			if cd.RequireTransportSecurity() {
				return nil, errTransportCredentialsMissing
			}
		}
	}

	if cc.dopts.defaultServiceConfigRawJSON != nil {
		scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON)
		if scpr.Err != nil {
			return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
		}
		cc.dopts.defaultServiceConfig, _ = scpr.Config.(*ServiceConfig)
	}
	cc.mkp = cc.dopts.copts.KeepaliveParams

	if cc.dopts.copts.Dialer == nil {
		cc.dopts.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
			network, addr := parseDialTarget(addr)
			return (&net.Dialer{}).DialContext(ctx, network, addr)
		}
		if cc.dopts.withProxy {
			cc.dopts.copts.Dialer = newProxyDialer(cc.dopts.copts.Dialer)
		}
	}

	if cc.dopts.copts.UserAgent != "" {
		cc.dopts.copts.UserAgent += " " + grpcUA
	} else {
		cc.dopts.copts.UserAgent = grpcUA
	}

	if cc.dopts.timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
		defer cancel()
	}
	defer func() {
		select {
		case <-ctx.Done():
			conn, err = nil, ctx.Err()
		default:
		}
	}()

	scSet := false
	if cc.dopts.scChan != nil {
		// Try to get an initial service config.
		select {
		case sc, ok := <-cc.dopts.scChan:
			if ok {
				cc.sc = &sc
				scSet = true
			}
		default:
		}
	}
	if cc.dopts.bs == nil {
		cc.dopts.bs = backoff.DefaultExponential
	}

	// Determine the resolver to use.
	cc.parsedTarget = grpcutil.ParseTarget(cc.target)
	channelz.Infof(cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
	resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
	if resolverBuilder == nil {
		// If resolver builder is still nil, the parsed target's scheme is
		// not registered. Fallback to default resolver and set Endpoint to
		// the original target.
		channelz.Infof(cc.channelzID, "scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme)
		cc.parsedTarget = resolver.Target{
			Scheme:   resolver.GetDefaultScheme(),
			Endpoint: target,
		}
		resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme)
		if resolverBuilder == nil {
			return nil, fmt.Errorf("could not get resolver for default scheme: %q", cc.parsedTarget.Scheme)
		}
	}

	creds := cc.dopts.copts.TransportCredentials
	if creds != nil && creds.Info().ServerName != "" {
		cc.authority = creds.Info().ServerName
	} else if cc.dopts.insecure && cc.dopts.authority != "" {
		cc.authority = cc.dopts.authority
	} else {
		// Use endpoint from "scheme://authority/endpoint" as the default
		// authority for ClientConn.
		cc.authority = cc.parsedTarget.Endpoint
	}

	if cc.dopts.scChan != nil && !scSet {
		// Blocking wait for the initial service config.
		select {
		case sc, ok := <-cc.dopts.scChan:
			if ok {
				cc.sc = &sc
			}
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	if cc.dopts.scChan != nil {
		go cc.scWatcher()
	}

	var credsClone credentials.TransportCredentials
	if creds := cc.dopts.copts.TransportCredentials; creds != nil {
		credsClone = creds.Clone()
	}
	cc.balancerBuildOpts = balancer.BuildOptions{
		DialCreds:        credsClone,
		CredsBundle:      cc.dopts.copts.CredsBundle,
		Dialer:           cc.dopts.copts.Dialer,
		ChannelzParentID: cc.channelzID,
		Target:           cc.parsedTarget,
	}

	// Build the resolver.
	rWrapper, err := newCCResolverWrapper(cc, resolverBuilder)
	if err != nil {
		return nil, fmt.Errorf("failed to build resolver: %v", err)
	}
	cc.mu.Lock()
	cc.resolverWrapper = rWrapper
	cc.mu.Unlock()

	// A blocking dial blocks until the clientConn is ready.
	if cc.dopts.block {
		for {
			s := cc.GetState()
			if s == connectivity.Ready {
				break
			} else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure {
				if err = cc.blockingpicker.connectionError(); err != nil {
					terr, ok := err.(interface {
						Temporary() bool
					})
					if ok && !terr.Temporary() {
						return nil, err
					}
				}
			}
			if !cc.WaitForStateChange(ctx, s) {
				// ctx got timeout or canceled.
				return nil, ctx.Err()
			}
		}
	}

	return cc, nil
}

The core logic of DialContext breaks down as follows.

1.1 Initialize the ClientConn struct

First, the ClientConn struct is initialized:

cc := &ClientConn{
		target:            target,
		csMgr:             &connectivityStateManager{},
		conns:             make(map[*addrConn]struct{}),
		dopts:             defaultDialOptions(),
		blockingpicker:    newPickerWrapper(),
		czData:            new(channelzData),
		firstResolveEvent: grpcsync.NewEvent(),
	}

target is the ip/port, and dopts holds the options, e.g. the commonly used WithInsecure and WithBlock.

The following loop applies the opts passed to Dial on top of the defaults; each of these option methods essentially just sets some field of the options struct:

for _, opt := range opts {
   opt.apply(&cc.dopts)
}

The dialOptions struct filled in by defaultDialOptions looks like this:

type dialOptions struct {
	unaryInt  UnaryClientInterceptor
	streamInt StreamClientInterceptor

	chainUnaryInts  []UnaryClientInterceptor
	chainStreamInts []StreamClientInterceptor

	cp          Compressor
	dc          Decompressor
	bs          internalbackoff.Strategy
	block       bool
	insecure    bool
	timeout     time.Duration
	scChan      <-chan ServiceConfig
	authority   string
	copts       transport.ConnectOptions
	callOptions []CallOption
	// This is used by v1 balancer dial option WithBalancer to support v1
	// balancer, and also by WithBalancerName dial option.
	balancerBuilder             balancer.Builder
	channelzParentID            int64
	disableServiceConfig        bool
	disableRetry                bool
	disableHealthCheck          bool
	healthCheckFunc             internal.HealthChecker
	minConnectTimeout           func() time.Duration
	defaultServiceConfig        *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON.
	defaultServiceConfigRawJSON *string
	// This is used by ccResolverWrapper to backoff between successive calls to
	// resolver.ResolveNow(). The user will have no need to configure this, but
	// we need to be able to configure this in tests.
	resolveNowBackoff func(int) time.Duration
	resolvers         []resolver.Builder
	withProxy         bool
}

You can see fields such as block and insecure; methods like WithInsecure ultimately just assign to the corresponding field of the default options.
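This is the functional options pattern; for instance, WithBlock in grpc-go is implemented roughly like this (paraphrased from the source):

func WithBlock() DialOption {
	return newFuncDialOption(func(o *dialOptions) {
		o.block = true
	})
}

newFuncDialOption wraps the closure in a small type whose apply method simply calls it with &cc.dopts.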

DialContext also initializes the connectivityStateManager and the pickerWrapper.

connectivityStateManager manages connection state; each connection is in one of the states "IDLE", "CONNECTING", "READY", "TRANSIENT_FAILURE", "SHUTDOWN", or "Invalid-State".
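These states are observable from user code. A minimal sketch of waiting for READY by hand, which is essentially what the WithBlock loop at the end of DialContext does (conn is the *ClientConn from Dial):

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for {
	s := conn.GetState() // google.golang.org/grpc/connectivity
	if s == connectivity.Ready {
		break
	}
	if !conn.WaitForStateChange(ctx, s) {
		log.Fatal("context expired before the connection became READY")
	}
}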

newPickerWrapper relates to load balancing; it wraps SubConns, each representing the address list of a server's multiple instances. Its internals deserve a deeper dive of their own.

1.2 Build the interceptor chain

chainUnaryClientInterceptors builds the interceptor chain:

func chainUnaryClientInterceptors(cc *ClientConn) {
	interceptors := cc.dopts.chainUnaryInts
	// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
	// be executed before any other chained interceptors.
	if cc.dopts.unaryInt != nil {
		interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
	}
	var chainedInt UnaryClientInterceptor
	if len(interceptors) == 0 {
		chainedInt = nil
	} else if len(interceptors) == 1 {
		chainedInt = interceptors[0]
	} else {
		chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
			return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
		}
	}
	cc.dopts.unaryInt = chainedInt
}

You can see that all the interceptors are merged into a single UnaryClientInterceptor, arranged as a chain of responsibility: cc.dopts.unaryInt ends up holding the first interceptor, whose invoker is set to the second interceptor, and so on down the line, so executing the first interceptor eventually executes them all in order.
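A sketch of registering two chained client interceptors to make that order visible (the log lines are illustrative):

func logA(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	log.Println("A before")
	err := invoker(ctx, method, req, reply, cc, opts...) // next in the chain: logB
	log.Println("A after")
	return err
}

func logB(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	log.Println("B before")
	err := invoker(ctx, method, req, reply, cc, opts...) // here invoker is the real invoke
	log.Println("B after")
	return err
}

conn, err := grpc.Dial(address,
	grpc.WithInsecure(),
	grpc.WithChainUnaryInterceptor(logA, logB),
)

A call then logs "A before", "B before", the RPC itself, "B after", "A after".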

1.3 Option-driven setup

Based on option parameters such as timeout and authority, DialContext then configures timeouts, credential checks, and so on.

2. Create the client from the connection

userServiceClient := pb.NewUserServiceClient(conn)

NewUserServiceClient is a function in the pb-generated Go file; it simply wraps the conn.
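Its typical generated shape is roughly the following (the exact output depends on the protoc plugin version):

type userServiceClient struct {
	cc grpc.ClientConnInterface
}

func NewUserServiceClient(cc grpc.ClientConnInterface) UserServiceClient {
	return &userServiceClient{cc}
}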

3. The RPC call

func (c *userServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*UserInfoDTO, error) {
	out := new(UserInfoDTO)
	err := c.cc.Invoke(ctx, "/service.UserService/login", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

The call goes through cc.Invoke, where cc is the conn wrapped inside the client.

The Invoke method is:

func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
	// allow interceptor to see all applicable call options, which means those
	// configured as defaults from dial option as well as per-call options
	opts = combine(cc.dopts.callOptions, opts)

	if cc.dopts.unaryInt != nil {
		return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
	}
	return invoke(ctx, method, args, reply, cc, opts...)
}

And then the lower-level invoke function:

func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
	cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
	if err != nil {
		return err
	}
	if err := cs.SendMsg(req); err != nil {
		return err
	}
	return cs.RecvMsg(reply)
}

You can see the SendMsg and RecvMsg calls.

3.1 SendMsg

SendMsg is as follows:

func (cs *clientStream) SendMsg(m interface{}) (err error) {
	defer func() {
		if err != nil && err != io.EOF {
			// Call finish on the client stream for errors generated by this SendMsg
			// call, as these indicate problems created by this client.  (Transport
			// errors are converted to an io.EOF error in csAttempt.sendMsg; the real
			// error will be returned from RecvMsg eventually in that case, or be
			// retried.)
			cs.finish(err)
		}
	}()
	if cs.sentLast {
		return status.Errorf(codes.Internal, "SendMsg called after CloseSend")
	}
	if !cs.desc.ClientStreams {
		cs.sentLast = true
	}

	// load hdr, payload, data
	hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
	if err != nil {
		return err
	}

	// TODO(dfawley): should we be checking len(data) instead?
	if len(payload) > *cs.callInfo.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
	}
	msgBytes := data // Store the pointer before setting to nil. For binary logging.
	op := func(a *csAttempt) error {
		err := a.sendMsg(m, hdr, payload, data)
		// nil out the message and uncomp when replaying; they are only needed for
		// stats which is disabled for subsequent attempts.
		m, data = nil, nil
		return err
	}
	err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
	if cs.binlog != nil && err == nil {
		cs.binlog.Log(&binarylog.ClientMessage{
			OnClientSide: true,
			Message:      msgBytes,
		})
	}
	return
}

The data is first prepared by prepareMsg, then sent via csAttempt's sendMsg:

func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error {
	cs := a.cs
	if a.trInfo != nil {
		a.mu.Lock()
		if a.trInfo.tr != nil {
			a.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
		}
		a.mu.Unlock()
	}
	if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams}); err != nil {
		if !cs.desc.ClientStreams {
			// For non-client-streaming RPCs, we return nil instead of EOF on error
			// because the generated code requires it.  finish is not called; RecvMsg()
			// will call it with the stream's status independently.
			return nil
		}
		return io.EOF
	}
	if a.statsHandler != nil {
		a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now()))
	}
	if channelz.IsOn() {
		a.t.IncrMsgSent()
	}
	return nil
}

Here the t in a.t is the Transport.

3.2 RecvMsg

func (cs *clientStream) RecvMsg(m interface{}) error {
	if cs.binlog != nil && !cs.serverHeaderBinlogged {
		// Call Header() to binary log header if it's not already logged.
		cs.Header()
	}
	var recvInfo *payloadInfo
	if cs.binlog != nil {
		recvInfo = &payloadInfo{}
	}
	err := cs.withRetry(func(a *csAttempt) error {
		return a.recvMsg(m, recvInfo)
	}, cs.commitAttemptLocked)
	if cs.binlog != nil && err == nil {
		cs.binlog.Log(&binarylog.ServerMessage{
			OnClientSide: true,
			Message:      recvInfo.uncompressedBytes,
		})
	}
	if err != nil || !cs.desc.ServerStreams {
		// err != nil or non-server-streaming indicates end of stream.
		cs.finish(err)

		if cs.binlog != nil {
			// finish will not log Trailer. Log Trailer here.
			logEntry := &binarylog.ServerTrailer{
				OnClientSide: true,
				Trailer:      cs.Trailer(),
				Err:          err,
			}
			if logEntry.Err == io.EOF {
				logEntry.Err = nil
			}
			if peer, ok := peer.FromContext(cs.Context()); ok {
				logEntry.PeerAddr = peer.Addr
			}
			cs.binlog.Log(logEntry)
		}
	}
	return err
}

RecvMsg likewise calls csAttempt's recvMsg, and after several layers of calls this reaches the parser's recvMsg method:

func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
	if _, err := p.r.Read(p.header[:]); err != nil {
		return 0, nil, err
	}

	pf = payloadFormat(p.header[0])
	length := binary.BigEndian.Uint32(p.header[1:])

	if length == 0 {
		return pf, nil, nil
	}
	if int64(length) > int64(maxInt) {
		return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
	}
	if int(length) > maxReceiveMessageSize {
		return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
	}
	// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
	// of making it for each message:
	msg = make([]byte, int(length))
	if _, err := p.r.Read(msg); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, nil, err
	}
	return pf, msg, nil
}

Note the Read call on the first line: p.r is an io.Reader, the same shape as making an HTTP call in Go or Java and reading back the response body. The five-byte header it reads is the gRPC wire framing: one byte for the payload format (whether the payload is compressed) followed by a four-byte big-endian length.
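A self-contained sketch of that framing (readFrame is a hypothetical helper name; it uses io.ReadFull where the real parser relies on its reader's framing guarantees):

// readFrame reads one length-prefixed gRPC message:
// 1 byte payload format (0 = plain, 1 = compressed),
// 4 bytes big-endian length, then the payload itself.
func readFrame(r io.Reader) (compressed bool, msg []byte, err error) {
	var header [5]byte
	if _, err = io.ReadFull(r, header[:]); err != nil {
		return false, nil, err
	}
	compressed = header[0] == 1
	length := binary.BigEndian.Uint32(header[1:5])
	msg = make([]byte, length)
	if _, err = io.ReadFull(r, msg); err != nil {
		return false, nil, err
	}
	return compressed, msg, nil
}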

Summary

This post traced the call flow through the gRPC source to get an overall picture of how gRPC works internally. Many details and other features were left untouched, so treat this as an introduction to the grpc source; when time allows, I will dig deeper into its code structure and networking internals.
