代码地址
https://github.com/wanmei002/grpc-learn/tree/master/ch05
拦截器简述
在远程方法执行之前或执行之后都需要做一些通用逻辑。gRPC有拦截器相关逻辑(相当于 gin 框架的 Use)。或做一些日志、身份验证、性能等需求,可以在拦截器里实现。
服务端拦截器
一元拦截器
一元拦截器需要实现 grpc/interceptor/type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
里的这个方法。
// 一元拦截器
func orderUnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
// 前置处理逻辑
typeO := reflect.TypeOf(req)
log.Printf("typeof name:%v, type of kind:%v;\n", typeO.Name(), typeO.Kind())
if _, ok := req.(*pb.Order); ok {
fmt.Println("req belong order type")
}
// 上面的都是前置处理逻辑
m, err := handler(ctx, req)
// 下面的都是后置处理逻辑
if err != nil {
log.Panicln("handler 处理的返回error:", err)
}
log.Printf("hander ret:%+v\n", m)
return m, err
}
服务端流拦截器
服务端流拦截器需要实现 grpc/interceptor/type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error
这个方法
主要代码如下:
// 实现 grpc.ServerStream 这个接口
type serverStreamInterceptor struct {
grpc.ServerStream
}
// 重写 RecvMsg 这个方法,这个是流接收数据的前置逻辑
func (si *serverStreamInterceptor) RecvMsg(m interface{}) error {
log.Printf("=====[server stream interceptor recv msg] " +
"rece msg type[%T]=====\n", m)
return si.ServerStream.RecvMsg(m)
}
// 重写 SendMsg 这个方法,这个是返回数据的后置逻辑
func (si *serverStreamInterceptor) SendMsg(m interface{}) error {
log.Printf("===== server stream interceptor send msg " +
"send msg type[%T]=====\n", m)
return si.ServerStream.SendMsg(m)
}
func NewServerStreamInterceptor(s grpc.ServerStream) grpc.ServerStream {
return &serverStreamInterceptor{s}
}
// 实现流拦截器的方法
func ProductServerStreamInterceptor(srv interface{}, ss grpc.ServerStream,
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("server stream interceptor , get data type[%T]; value : [%v]\n", srv, srv)
err := handler(srv, NewServerStreamInterceptor(ss))
if err != nil {
log.Println("server stream handler err:", err)
}
return err
}
入口方法
func main(){
port := ":8093"
ls, err := net.Listen("tcp", port)
if err != nil {
log.Println("listen failed; err:", err)
return
}
g := grpc.NewServer(grpc.UnaryInterceptor(orderUnaryServerInterceptor), grpc.StreamInterceptor(interceptor.ProductServerStreamInterceptor)) // 注册拦截器
pb.RegisterProductServer(g, &server{})
log.Println("server start, port :", port)
if err = g.Serve(ls); err != nil {
log.Println("grpc serve failed; err:", err)
}
}