前言
前面使用了 gRPC 进行客户端和服务端之间的数据传输。客户端每次请求前都要先 Dial,用完之后立即 Close,下一次请求进来又重新 Dial,这样资源消耗十分严重。于是我在 rfyiamcool 写的连接池基础上稍作修改,实现了连接的复用。
先上对比(第一组为每次请求都重新 Dial 的结果,第二组为使用连接池后的结果):
go test -bench=. -run=none
goos: linux
goarch: amd64
pkg: client
BenchmarkRpc 1462 805166 ns/op
PASS
ok client 1.267s
go test -bench=. -run=none
goos: linux
goarch: amd64
pkg: client
BenchmarkRpc 5713 220720 ns/op
PASS
ok client 2.272s
可以看到,单次调用耗时从约 805 µs 降到约 221 µs,执行速度提升了约 3.6 倍。
代码
ServiceClientPool 中的 clients 是以服务名(service name)为键、*ClientPool 为值的 map。同一个 target 下的各个 service 共用同一个 ClientPool 来维护,一个 ClientPool 中又有多个 conn 可供使用:通过原子递增计数器对连接数取余的方式轮询选取(取余可保证下标不会越界)。
package common
import (
"context"
"errors"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"strings"
"sync"
"sync/atomic"
"time"
)
// Sentinel errors returned by the pool.
var (
	// ErrNotFoundClient is returned when no ClientPool is registered for the
	// requested service name.
	ErrNotFoundClient = errors.New("not found grpc conn")
	// ErrConnShutdown is returned by checkState for a conn that is in the
	// TransientFailure or Shutdown state.
	ErrConnShutdown = errors.New("grpc conn shutdown")
)

// Default settings, used by the package init and substituted by
// NewClientPoolWithOption for any non-positive ClientOption field.
var (
	defaultClientPoolConnsSizeCap = 5
	defaultDialTimeout            = 5 * time.Second
	defaultKeepAlive              = 30 * time.Second
	defaultKeepAliveTimeout       = 10 * time.Second
)
// ClientOption configures a ClientPool. Any non-positive field is replaced by
// the corresponding package default in NewClientPoolWithOption.
type ClientOption struct {
	// ClientPoolConnsSize is the number of connection slots per ClientPool.
	ClientPoolConnsSize int
	// DialTimeOut bounds each blocking dial. KeepAlive and KeepAliveTimeout are
	// passed to keepalive.ClientParameters as Time and Timeout respectively.
	DialTimeOut, KeepAlive, KeepAliveTimeout time.Duration
}
// ClientPool holds a fixed-size set of gRPC connections to one target and
// hands them out round-robin (see getConn).
type ClientPool struct {
	// target is the address passed to grpc.DialContext.
	target string
	// option supplies the slot count, dial timeout and keepalive settings.
	option *ClientOption

	// next is the atomically incremented selection counter; cap is the
	// number of slots in conns.
	next, cap int64

	// The embedded mutex serializes replacement of entries in conns.
	sync.Mutex
	// conns holds one conn per slot; a slot stays nil until it is first dialed.
	conns []*grpc.ClientConn
}
// getConn returns a usable connection from the pool, dialing a new one when the
// selected slot is empty or unhealthy.
//
// Slots are chosen round-robin: an atomic counter is incremented on every call
// and reduced modulo the number of slots. A slot whose conn is nil or fails
// checkState is redialed, and the broken conn it held is closed.
//
// Every read and write of cc.conns happens under cc's mutex, so concurrent
// callers never race on a slot. The dial itself runs without the lock, so a
// slow or failing dial (bounded by option.DialTimeOut) does not stall callers
// whose slots are healthy. After dialing, the slot is re-read. If another
// goroutine already repaired it, that conn is kept and ours is closed, so no
// conn is leaked.
//
// Returns ErrNotFoundClient if the pool has no slots, or the dial error.
func (cc *ClientPool) getConn() (*grpc.ClientConn, error) {
	size := uint64(len(cc.conns))
	if size == 0 {
		return nil, ErrNotFoundClient
	}
	// Reduce in uint64 so the index stays in range even after the int64
	// counter wraps around to negative values.
	idx := uint64(atomic.AddInt64(&cc.next, 1)) % size

	cc.Lock()
	conn := cc.conns[idx]
	if conn != nil && cc.checkState(conn) == nil {
		cc.Unlock()
		return conn, nil
	}
	cc.Unlock()

	fresh, err := cc.connect()
	if err != nil {
		return nil, err
	}

	cc.Lock()
	defer cc.Unlock()
	// Double check against the slot's current content, not the stale local:
	// another goroutine may have replaced it while we were dialing.
	if cur := cc.conns[idx]; cur != nil {
		if cc.checkState(cur) == nil {
			_ = fresh.Close() // best-effort: this conn was never handed out
			return cur, nil
		}
		_ = cur.Close() // best-effort cleanup of the broken conn being replaced
	}
	cc.conns[idx] = fresh
	return fresh, nil
}
// checkState reports ErrConnShutdown when conn is in the TransientFailure or
// Shutdown connectivity state, and nil for every other state.
func (cc *ClientPool) checkState(conn *grpc.ClientConn) error {
	if s := conn.GetState(); s == connectivity.TransientFailure || s == connectivity.Shutdown {
		return ErrConnShutdown
	}
	return nil
}
// connect dials cc.target without transport security (grpc.WithInsecure).
// Because of grpc.WithBlock, it waits until the connection is up or
// option.DialTimeOut elapses. The conn uses the pool's keepalive settings.
func (cc *ClientPool) connect() (*grpc.ClientConn, error) {
	dialCtx, cancel := context.WithTimeout(context.TODO(), cc.option.DialTimeOut)
	defer cancel()

	kp := keepalive.ClientParameters{
		Time:    cc.option.KeepAlive,
		Timeout: cc.option.KeepAliveTimeout,
	}
	conn, err := grpc.DialContext(dialCtx, cc.target,
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithKeepaliveParams(kp),
	)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// Close closes every dialed connection in the pool and clears its slot.
//
// Clearing the slots means a later getConn redials instead of finding (and
// closing a second time) an already-closed conn. Errors from the individual
// Close calls are ignored because shutdown is best-effort.
func (cc *ClientPool) Close() {
	cc.Lock()
	defer cc.Unlock()
	for i, conn := range cc.conns {
		if conn == nil {
			continue
		}
		_ = conn.Close() // best-effort: nothing useful to do with the error here
		cc.conns[i] = nil
	}
}
// NewClientPoolWithOption creates a ClientPool for target. It does not dial;
// connections are created lazily by getConn.
//
// Any non-positive field of option falls back to the package default. option
// is copied first, so the caller's struct is never modified. This matters
// because ServiceClientPool.Init passes the same pointer for every target. A
// nil option means "use all defaults".
func NewClientPoolWithOption(target string, option *ClientOption) *ClientPool {
	var opt ClientOption
	if option != nil {
		opt = *option
	}
	if opt.ClientPoolConnsSize <= 0 {
		opt.ClientPoolConnsSize = defaultClientPoolConnsSizeCap
	}
	if opt.DialTimeOut <= 0 {
		opt.DialTimeOut = defaultDialTimeout
	}
	if opt.KeepAlive <= 0 {
		opt.KeepAlive = defaultKeepAlive
	}
	if opt.KeepAliveTimeout <= 0 {
		opt.KeepAliveTimeout = defaultKeepAliveTimeout
	}
	return &ClientPool{
		target: target,
		option: &opt,
		cap:    int64(opt.ClientPoolConnsSize),
		conns:  make([]*grpc.ClientConn, opt.ClientPoolConnsSize),
	}
}
// TargetServiceNames maps a dial target (e.g. "host:port") to the gRPC service
// names served there. Its zero value is ready to use.
type TargetServiceNames struct {
	m map[string][]string
}

// NewTargetServiceNames returns an empty TargetServiceNames.
func NewTargetServiceNames() *TargetServiceNames {
	return &TargetServiceNames{
		m: make(map[string][]string),
	}
}

// Set appends serviceNames to the list registered for target. Calling it with
// no names is a no-op, so it never creates an empty entry.
func (h *TargetServiceNames) Set(target string, serviceNames ...string) {
	if len(serviceNames) == 0 {
		return
	}
	if h.m == nil {
		// Allocate lazily so a zero-value TargetServiceNames does not panic on
		// the map write below.
		h.m = make(map[string][]string)
	}
	h.m[target] = append(h.m[target], serviceNames...)
}

// list returns the underlying target -> service-names map (not a copy).
func (h *TargetServiceNames) list() map[string][]string {
	return h.m
}

// len returns the number of distinct targets registered.
func (h *TargetServiceNames) len() int {
	return len(h.m)
}
// ServiceClientPool looks up a ClientPool by gRPC service name (the clients
// map is keyed by service name) and takes a *grpc.ClientConn from it. All
// services registered under the same target share one ClientPool (see Init).
type ServiceClientPool struct {
	// clients maps service name -> ClientPool; populated by Init.
	clients map[string]*ClientPool
	// option is passed to NewClientPoolWithOption for every target in Init.
	option *ClientOption
	// clientCap is the per-pool slot count, set by NewServiceClientPool.
	clientCap int
}
// NewServiceClientPool creates an empty ServiceClientPool; call Init to
// register targets and their services.
//
// A nil option is treated as an empty ClientOption (all defaults) instead of
// causing a nil dereference. clientCap records the per-pool slot count. When
// option leaves it unset, clientCap falls back to the package default, the same
// value NewClientPoolWithOption substitutes for every pool.
func NewServiceClientPool(option *ClientOption) *ServiceClientPool {
	if option == nil {
		option = &ClientOption{}
	}
	size := option.ClientPoolConnsSize
	if size <= 0 {
		size = defaultClientPoolConnsSizeCap
	}
	return &ServiceClientPool{
		option:    option,
		clientCap: size,
	}
}
// Init builds one ClientPool per target in m and indexes that pool under each of
// the target's service names. It replaces any previous mapping. No connection
// is dialed here; pools connect lazily.
func (sc *ServiceClientPool) Init(m *TargetServiceNames) {
	byService := make(map[string]*ClientPool, m.len())
	for target, services := range m.list() {
		pool := NewClientPoolWithOption(target, sc.option)
		for _, name := range services {
			byService[name] = pool
		}
	}
	sc.clients = byService
}
// GetClientWithFullMethod resolves a full method name such as
// "/hello.FirstService/SayHello" to its service name and returns a conn from
// that service's pool.
func (sc *ServiceClientPool) GetClientWithFullMethod(fullMethod string) (*grpc.ClientConn, error) {
	return sc.GetClient(sc.SpiltFullMethod(fullMethod))
}
// GetClient returns a conn from the ClientPool registered for service sname,
// or ErrNotFoundClient when no pool is registered under that name.
func (sc *ServiceClientPool) GetClient(sname string) (*grpc.ClientConn, error) {
	if pool, found := sc.clients[sname]; found {
		return pool.getConn()
	}
	return nil, ErrNotFoundClient
}
// Close closes the ClientPool serving sname; unknown names are ignored.
// The pool is shared by every service registered on the same target, so
// closing it affects all of them.
func (sc *ServiceClientPool) Close(sname string) {
	if pool, found := sc.clients[sname]; found {
		pool.Close()
	}
}
// CloseAll closes every registered ClientPool exactly once.
//
// Init indexes one pool per target under each of that target's service
// names, so several map entries can point at the same pool. Pools are
// de-duplicated so none is closed more than once.
func (sc *ServiceClientPool) CloseAll() {
	closed := make(map[*ClientPool]struct{}, len(sc.clients))
	for _, pool := range sc.clients {
		if _, done := closed[pool]; done {
			continue
		}
		closed[pool] = struct{}{}
		pool.Close()
	}
}
// SpiltFullMethod extracts the service part of a gRPC full method name of the
// form "/package.Service/Method". A string that does not contain exactly two
// '/' separators yields "".
func (sc *ServiceClientPool) SpiltFullMethod(fullMethod string) string {
	// Ask for at most 4 parts: a 4th part only appears when there are three or
	// more separators, so exactly 3 parts means exactly two separators.
	parts := strings.SplitN(fullMethod, "/", 4)
	if len(parts) == 3 {
		return parts[1]
	}
	return ""
}
// Invoke calls fullMethod on a conn taken from its service's pool.
//
// Outgoing metadata already attached to ctx is copied, and each entry of
// headers is then Set on the copy (replacing any existing values for that
// key). The call uses the resulting context. Errors from the pool lookup are
// returned before any call is made.
func (sc *ServiceClientPool) Invoke(ctx context.Context, fullMethod string, headers map[string]string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
	conn, err := sc.GetClient(sc.SpiltFullMethod(fullMethod))
	if err != nil {
		return err
	}

	outgoing := metadata.MD{}
	if existing, ok := metadata.FromOutgoingContext(ctx); ok {
		// Copy so the caller's metadata is not mutated by the Set calls below.
		outgoing = existing.Copy()
	}
	for key, value := range headers {
		outgoing.Set(key, value)
	}

	return conn.Invoke(metadata.NewOutgoingContext(ctx, outgoing), fullMethod, args, reply, opts...)
}
// ADDRESS is the target of the demo gRPC server registered by the package init.
const ADDRESS = "127.0.0.1:7778"

// SERVICENAME is the gRPC service name registered for ADDRESS.
const SERVICENAME = "hello.FirstService"
// scp is the package-wide ServiceClientPool built in init and returned by GetScp.
var scp *ServiceClientPool

// init registers SERVICENAME under ADDRESS and builds scp with the default
// pool settings. No connection is dialed here; conns are created on first use.
func init() {
	tsn := NewTargetServiceNames()
	tsn.Set(ADDRESS, SERVICENAME)

	scp = NewServiceClientPool(&ClientOption{
		ClientPoolConnsSize: defaultClientPoolConnsSizeCap,
		DialTimeOut:         defaultDialTimeout,
		KeepAlive:           defaultKeepAlive,
		KeepAliveTimeout:    defaultKeepAliveTimeout,
	})
	scp.Init(tsn)
}

// GetScp returns the package-wide ServiceClientPool configured in init.
func GetScp() *ServiceClientPool {
	return scp
}
使用示例
// Take the shared pool and time one call through it.
clientPool := common.GetScp()
reply := &pb.HelloReply{}
// Start the clock right before the call so only the RPC is measured.
startTime := time.Now()
err := clientPool.Invoke(context.Background(), "/hello.FirstService/SayHello", nil, &pb.HelloRequest{Name: "lubenwei", Age: "21"}, reply)
fmt.Println("耗时:", time.Since(startTime))
if err != nil {
	// Any failure lands here (unknown service, dial error, RPC status),
	// not only timeouts.
	fmt.Println("调用失败:", err)
	return
}
fmt.Println(reply.Time)