go语言之thrift协议之client端分析
上一篇文章分析了thrift协议server端的实现,本文仍然基于官方示例,分析client端的实现。
import (
"crypto/tls"
"flag"
"fmt"
"os"
"github.com/apache/thrift/lib/go/thrift"
)
// Usage writes a "Usage of <program>:" banner to stderr, then the defaults
// of all registered flags (via flag.PrintDefaults), then a trailing newline.
func Usage() {
	prog := os.Args[0]
	fmt.Fprintf(os.Stderr, "Usage of %s:\n", prog)
	flag.PrintDefaults()
	fmt.Fprintln(os.Stderr)
}
// main parses the command-line flags, builds the protocol and transport
// factories they select, and then runs either the tutorial server
// (-server) or the tutorial client against -addr. A failed run is reported
// and the process exits with status 1.
func main() {
	flag.Usage = Usage
	server := flag.Bool("server", false, "Run server")
	protocol := flag.String("P", "compact", "Specify the protocol (binary, compact, json, simplejson)")
	framed := flag.Bool("framed", false, "Use framed transport")
	buffered := flag.Bool("buffered", false, "Use buffered transport")
	addr := flag.String("addr", "localhost:9090", "Address to listen to")
	secure := flag.Bool("secure", false, "Use tls secure transport")
	flag.Parse()

	// Select the wire encoding; an empty -P falls back to binary.
	var protocolFactory thrift.TProtocolFactory
	switch *protocol {
	case "compact":
		protocolFactory = thrift.NewTCompactProtocolFactoryConf(nil)
	case "simplejson":
		protocolFactory = thrift.NewTSimpleJSONProtocolFactoryConf(nil)
	case "json":
		protocolFactory = thrift.NewTJSONProtocolFactory()
	case "binary", "":
		protocolFactory = thrift.NewTBinaryProtocolFactoryConf(nil)
	default:
		// Dereference the flag: printing the *string itself would show a
		// pointer address instead of the value the user passed.
		fmt.Fprintf(os.Stderr, "Invalid protocol specified: %q\n", *protocol)
		Usage()
		os.Exit(1)
	}

	var transportFactory thrift.TTransportFactory
	// NOTE(review): InsecureSkipVerify disables TLS certificate verification.
	// It is presumably here so the demo can talk to a server using a
	// self-signed certificate; never copy this into production code.
	cfg := &thrift.TConfiguration{
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
	}
	if *buffered {
		transportFactory = thrift.NewTBufferedTransportFactory(8192)
	} else {
		transportFactory = thrift.NewTTransportFactory()
	}
	// Framing wraps whichever base transport factory was chosen above.
	if *framed {
		transportFactory = thrift.NewTFramedTransportFactoryConf(transportFactory, cfg)
	}

	if *server {
		if err := runServer(transportFactory, protocolFactory, *addr, *secure); err != nil {
			fmt.Println("error running server:", err)
			os.Exit(1)
		}
	} else {
		if err := runClient(transportFactory, protocolFactory, *addr, *secure, cfg); err != nil {
			fmt.Println("error running client:", err)
			os.Exit(1)
		}
	}
}
这部分代码与server端共用,因此这里不再具体分析,详情可以参考上一篇。
runClient
还是先看一下具体的实现
// runClient creates a socket transport for addr (a TLS socket when secure
// is set), lets transportFactory wrap it, opens the connection, and runs
// the tutorial calls through a CalculatorClient whose input and output
// protocols both come from protocolFactory. The transport is closed when
// the function returns.
func runClient(transportFactory thrift.TTransportFactory, protocolFactory thrift.TProtocolFactory, addr string, secure bool, cfg *thrift.TConfiguration) error {
	// Raw socket transport; not connected yet.
	var raw thrift.TTransport
	switch {
	case secure:
		raw = thrift.NewTSSLSocketConf(addr, cfg)
	default:
		raw = thrift.NewTSocketConf(addr, cfg)
	}

	// The factory may wrap the socket (e.g. buffered/framed) or return it as is.
	transport, err := transportFactory.GetTransport(raw)
	if err != nil {
		return err
	}
	defer transport.Close()

	// Dial the server.
	if err = transport.Open(); err != nil {
		return err
	}

	// Inbound (read) and outbound (write) protocols share one transport.
	in := protocolFactory.GetProtocol(transport)
	out := protocolFactory.GetProtocol(transport)
	std := thrift.NewTStandardClient(in, out)
	return handleClient(tutorial.NewCalculatorClient(std))
}
首先是初始化client的transport。这里没有开启 -secure(TLS),所以走的是 transport = thrift.NewTSocketConf(addr, cfg)。和server端一样,它最终初始化的是TSocket这个结构体,如下:
// NewTSocketFromAddrConf returns a TSocket for addr that uses conf. Only the
// address and configuration are recorded; no connection is made until Open
// is called.
func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket {
	sock := new(TSocket)
	sock.addr = addr
	sock.cfg = conf
	return sock
}
然后调用transportFactory.GetTransport。由于没有指定 -buffered 或 -framed,transportFactory 就是默认的 thrift.NewTTransportFactory(),它的GetTransport会原封不动地返回传入的transport,所以这里拿到的仍然是TSocket结构体。接着调用Open方法,根据地址建立连接。
Open
// Open dials p.addr using the connect timeout from p.cfg.
//
// It returns an ALREADY_OPEN transport exception if p.conn is already
// valid, and a NOT_OPEN exception if the address is nil, has an empty
// network or address string, or the dial fails. On success p.conn holds the
// new connection and p.addr is replaced by the connection's remote address.
func (p *TSocket) Open() error {
	switch {
	case p.conn.isValid():
		return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
	case p.addr == nil:
		return NewTTransportException(NOT_OPEN, "Cannot open nil address.")
	case p.addr.Network() == "":
		return NewTTransportException(NOT_OPEN, "Cannot open bad network name.")
	case p.addr.String() == "":
		return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
	}

	network, address := p.addr.Network(), p.addr.String()
	var dialErr error
	// p.conn is assigned even when the dial fails, exactly as before.
	p.conn, dialErr = createSocketConnFromReturn(net.DialTimeout(network, address, p.cfg.GetConnectTimeout()))
	if dialErr != nil {
		return &tTransportException{
			typeId: NOT_OPEN,
			err:    dialErr,
			msg:    dialErr.Error(),
		}
	}
	p.addr = p.conn.RemoteAddr()
	return nil
}
这里的逻辑也比较简单:先做几项校验(是否已经连接、地址是否合法),然后通过标准库的net.DialTimeout建立连接,超时时间取自cfg.GetConnectTimeout();连接成功后,把p.addr更新为对端地址。
protocolFactory.GetProtocol
这里的实现和server端一样,根据compact协议返回一个TProtocol(也就是TCompactProtocol):
// GetProtocol wraps trans in a new TCompactProtocol that uses the factory's
// TConfiguration.
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
	proto := NewTCompactProtocolConf(trans, p.cfg)
	return proto
}
// NewTCompactProtocolConf builds a TCompactProtocol over trans.
//
// trans and conf are first handed to PropagateTConfiguration. The protocol
// keeps trans as its original transport. For reading and writing it uses
// trans directly when trans already implements TRichTransport, and otherwise
// wraps it with NewTRichTransport.
func NewTCompactProtocolConf(trans TTransport, conf *TConfiguration) *TCompactProtocol {
	PropagateTConfiguration(trans, conf)
	rich, isRich := trans.(TRichTransport)
	if !isRich {
		rich = NewTRichTransport(trans)
	}
	return &TCompactProtocol{
		origTransport: trans,
		cfg:           conf,
		trans:         rich,
	}
}
handleClient
注意这里首先需要初始化client,也就是 tutorial.NewCalculatorClient(thrift.NewTStandardClient(iprot, oprot))。下面逐步解析。
NewTStandardClient
// NewTStandardClient returns a TStandardClient, the TClient implementation
// that uses Thrift's standard message format. Replies are read from
// inputProtocol and requests are written to outputProtocol.
// The returned client is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
	client := new(TStandardClient)
	client.iprot = inputProtocol
	client.oprot = outputProtocol
	return client
}
// TClient issues a single RPC: Call sends method with args and decodes the
// reply into result, returning any response metadata. (TStandardClient
// treats a nil result as a oneway call and skips the reply.)
type TClient interface {
	Call(ctx context.Context, method string, args TStruct, result TStruct) (ResponseMeta, error)
}
首先需要注意,TStandardClient这个结构体实现了TClient这个interface,这也是它的核心功能,后面很多地方都会用到。接下来看一下Call方法(它的args和result参数都是TStruct类型):
// TStruct is a value that can serialize itself to a TProtocol (Write) and
// deserialize itself from one (Read); TClient.Call takes its args and
// result as TStructs.
type TStruct interface {
	Write(ctx context.Context, prot TProtocol) error
	Read(ctx context.Context, prot TProtocol) error
}
// Call performs one RPC. It assigns the next sequence id, sends method and
// args, and (unless result is nil, i.e. a oneway method) reads the reply
// into result. When the input protocol is a *THeaderProtocol, the headers it
// read are returned in ResponseMeta, together with any Recv error.
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
	// Unsynchronized increment: this is why the client is not concurrency-safe.
	p.seqId++
	id := p.seqId

	if sendErr := p.Send(ctx, p.oprot, id, method, args); sendErr != nil {
		return ResponseMeta{}, sendErr
	}
	if result == nil { // oneway: no reply to wait for
		return ResponseMeta{}, nil
	}

	recvErr := p.Recv(ctx, p.iprot, id, method, result)
	meta := ResponseMeta{}
	if hp, isHeader := p.iprot.(*THeaderProtocol); isHeader {
		meta.Headers = hp.transport.readHeaders
	}
	return meta, recvErr
}
这里的seqId每次调用都会自增一次,Recv中会用它来校验响应是否与请求对应。由于自增时没有加锁,TStandardClient不能被并发使用(上面的注释也说明了这一点)。
可以看出,这个方法主要由Send和Recv两个方法构成,下面分别看一下Send方法和Recv方法。
// Send writes one request message to oprot and flushes it.
//
// If oprot is a *THeaderProtocol, its write headers are cleared and then
// refilled from the header keys listed in ctx. The message consists of
// WriteMessageBegin(method, CALL, seqId), the serialized args, and
// WriteMessageEnd, followed by Flush. The first error encountered is
// returned.
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
	if hp, isHeader := oprot.(*THeaderProtocol); isHeader {
		hp.ClearWriteHeaders()
		for _, k := range GetWriteHeaderList(ctx) {
			v, found := GetHeader(ctx, k)
			if !found {
				continue
			}
			hp.SetWriteHeader(k, v)
		}
	}

	if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
		return err
	}
	// args is an interface; the concrete (generated) type serializes itself.
	if err := args.Write(ctx, oprot); err != nil {
		return err
	}
	if err := oprot.WriteMessageEnd(ctx); err != nil {
		return err
	}
	return oprot.Flush(ctx)
}
这里没有多余的逻辑:依次调用compact协议的WriteMessageBegin、args的Write方法和WriteMessageEnd,最后调用Flush把数据发送出去。
然后看一下Recv方法。
// Recv reads one reply message from iprot into result.
//
// The message header must name method and carry seqId; otherwise a
// WRONG_METHOD_NAME or BAD_SEQUENCE_ID application exception is returned.
// An EXCEPTION message is decoded and returned as an application exception.
// Any other non-REPLY type yields INVALID_MESSAGE_TYPE_EXCEPTION. For a
// REPLY, result is read and the message end consumed.
func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
	gotMethod, gotType, gotSeq, err := iprot.ReadMessageBegin(ctx)
	if err != nil {
		return err
	}

	switch {
	case gotMethod != method:
		return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
	case gotSeq != seqId:
		return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
	case gotType == EXCEPTION:
		var appErr tApplicationException
		if err := appErr.Read(ctx, iprot); err != nil {
			return err
		}
		if err := iprot.ReadMessageEnd(ctx); err != nil {
			return err
		}
		return &appErr
	case gotType != REPLY:
		return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
	}

	if err := result.Read(ctx, iprot); err != nil {
		return err
	}
	return iprot.ReadMessageEnd(ctx)
}
Recv方法和Send类似:先调用compact的ReadMessageBegin读取消息头,并依次校验方法名、seqId和消息类型(如果是EXCEPTION,就读取并返回服务端抛出的异常);然后调用result.Read读取结果,最后调用ReadMessageEnd。
NewCalculatorClient
需要注意,这是thrift根据IDL生成的代码。入参就是上面创建的TStandardClient结构体,它实现了TClient这个interface(也就是Call方法)。
// NewCalculatorClient returns a CalculatorClient that sends its RPCs
// through c, by embedding a SharedServiceClient built on the same TClient.
func NewCalculatorClient(c thrift.TClient) *CalculatorClient {
	base := shared.NewSharedServiceClient(c)
	return &CalculatorClient{SharedServiceClient: base}
}
// NewSharedServiceClient returns a SharedServiceClient that issues its RPCs
// through c.
func NewSharedServiceClient(c thrift.TClient) *SharedServiceClient {
	client := new(SharedServiceClient)
	client.c = c
	return client
}
handleClient的具体实现
// handleClient drives the tutorial Calculator service through a fixed
// sequence of calls and prints the results to stdout:
// ping, add(1, 1), a deliberate divide-by-zero, 15-10, and getStruct(1).
//
// The divide-by-zero call is expected to fail; its error is only printed.
// Any other failed call aborts the sequence, and its error is returned.
// Previously the ping and add errors were discarded, so "ping()" and
// "1+1=0" could be printed after a failed call.
func handleClient(client *tutorial.CalculatorClient) (err error) {
	if err = client.Ping(defaultCtx); err != nil {
		return fmt.Errorf("ping: %w", err)
	}
	fmt.Println("ping()")

	sum, err := client.Add(defaultCtx, 1, 1)
	if err != nil {
		return fmt.Errorf("add: %w", err)
	}
	fmt.Print("1+1=", sum, "\n")

	// Division by zero: the server is expected to reply with InvalidOperation.
	work := tutorial.NewWork()
	work.Op = tutorial.Operation_DIVIDE
	work.Num1 = 1
	work.Num2 = 0
	quotient, err := client.Calculate(defaultCtx, 1, work)
	if err != nil {
		switch v := err.(type) {
		case *tutorial.InvalidOperation:
			fmt.Println("Invalid operation:", v)
		default:
			fmt.Println("Error during operation:", err)
		}
	} else {
		fmt.Println("Whoa we can divide by 0 with new value:", quotient)
	}

	work.Op = tutorial.Operation_SUBTRACT
	work.Num1 = 15
	work.Num2 = 10
	diff, err := client.Calculate(defaultCtx, 1, work)
	if err != nil {
		switch v := err.(type) {
		case *tutorial.InvalidOperation:
			fmt.Println("Invalid operation:", v)
		default:
			fmt.Println("Error during operation:", err)
		}
		return err
	}
	fmt.Print("15-10=", diff, "\n")

	// Named "entry" rather than "log" to avoid shadowing the log package.
	entry, err := client.GetStruct(defaultCtx, 1)
	if err != nil {
		fmt.Println("Unable to get struct:", err)
		return err
	}
	fmt.Println("Check log:", entry.Value)
	return nil
}
这个方法从使用上来看很好理解:主要就是组装参数,然后调用client所实现的方法。接下来看一下这些方法在thrift生成代码中的实现。由于这些方法的实现大同小异,这里只看Add这个方法,代码如下:
// CalculatorAddArgs carries the arguments of the "add" RPC
// (filled in by CalculatorClient.Add and passed to TClient.Call).
// Attributes:
//   - Num1: first operand (`thrift` tag: name "num1", field id 1)
//   - Num2: second operand (`thrift` tag: name "num2", field id 2)
type CalculatorAddArgs struct {
	Num1 int32 `thrift:"num1,1" db:"num1" json:"num1"`
	Num2 int32 `thrift:"num2,2" db:"num2" json:"num2"`
}
// CalculatorAddResult receives the reply of the "add" RPC; TClient.Call
// decodes into it and CalculatorClient.Add reads it via GetSuccess.
// Attributes:
//   - Success: the returned sum (field id 0). It is a pointer, presumably so
//     that an unset value can be told apart from 0; it is omitted from JSON
//     when nil.
type CalculatorAddResult struct {
	Success *int32 `thrift:"success,0" db:"success" json:"success,omitempty"`
}
// ResponseMeta represents the metadata attached to the response.
// TStandardClient.Call fills it in only when its input protocol is a
// *THeaderProtocol; otherwise it returns the zero value.
type ResponseMeta struct {
	// The headers in the response, if any.
	// If the underlying transport/protocol is not THeader, this will always be nil.
	Headers THeaderMap
}
// Add calls the remote "add" method with num1 and num2 and returns the sum.
// The response metadata is always recorded via SetLastResponseMeta_, even
// when the call fails; on failure the sum is 0 and the Call error is
// returned.
//
// Parameters:
//  - Num1
//  - Num2
func (p *CalculatorClient) Add(ctx context.Context, num1 int32, num2 int32) (_r int32, _err error) {
	args := CalculatorAddArgs{Num1: num1, Num2: num2}
	var res CalculatorAddResult
	meta, err := p.Client_().Call(ctx, "add", &args, &res)
	p.SetLastResponseMeta_(meta)
	if err != nil {
		return 0, err
	}
	return res.GetSuccess(), nil
}
这里的CalculatorAddArgs、CalculatorAddResult和ResponseMeta的定义如上所示,Add最终也是调用Call方法。之所以传入"add",是因为server端会根据方法名称去查找对应的handler,所以这里的method名称是add。拿到结果后,通过GetSuccess获取CalculatorAddResult中的Success成员,这里就是int32。
到这里client端基本就分析完了。当然这还远远不够,接下来还需要分析thrift中不同的proto对TProtocol的具体实现。