以 thrift 框架简单使用 这篇文章为例,用gdb 跟踪一下server 的执行流程,以此学习下thrift 的源码。在此先贴出服务端代码:
#include "Hello.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using boost::shared_ptr;
using namespace ::demo;
// Handler for the Hello service: implements the RPC methods declared in
// the generated HelloIf interface.
class HelloHandler : virtual public HelloIf {
public:
  // Nothing to initialize for this demo service.
  HelloHandler() {
  }

  // helloString RPC: logs the incoming request, then fills the reply struct.
  //   _return - output struct serialized back to the client
  //   sIn     - request struct received from the client
  void helloString(helloOut& _return, const helloIn& sIn) {
    printf("server get name:%s, age:%d\n", sIn.name.c_str(), sIn.age);
    _return.resCode = 0;
    _return.msg = "Hello";
  }
};
// Entry point: wires up the Thrift stack (handler -> processor, socket ->
// transport -> protocol) and runs a single-threaded blocking server.
int main(int argc, char **argv) {
  const int listenPort = 9090;

  // Business-logic handler and the generated processor that dispatches to it.
  shared_ptr<HelloHandler> handler(new HelloHandler());
  shared_ptr<TProcessor> processor(new HelloProcessor(handler));

  // Listening socket plus the transport/protocol factories applied per client.
  shared_ptr<TServerTransport> serverTransport(new TServerSocket(listenPort));
  shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
  shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());

  // TSimpleServer handles one connection at a time; serve() blocks forever.
  TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
  server.serve();
  return 0;
}
服务端的核心行为在 serve() 中。用 gdb --args ./server 启动服务端,在 serve 处设断点跟踪,看看 serve 做了什么
serve
src/thrift/server/TServerFramework.cpp
// Main accept loop of the server framework: listens, accepts clients one at
// a time, wraps each raw connection in transports/protocols, and hands it to
// newlyConnectedClient(). Runs until the server transport is interrupted.
void TServerFramework::serve() {
shared_ptr<TTransport> client;
shared_ptr<TTransport> inputTransport;
shared_ptr<TTransport> outputTransport;
shared_ptr<TProtocol> inputProtocol;
shared_ptr<TProtocol> outputProtocol;
// Start listening for connections.
serverTransport_->listen();
// Notify the event handler (if any) before serving starts.
if (eventHandler_) {
eventHandler_->preServe();
}
// Accept loop: fetch client connections one by one.
for (;;) {
try {
// Release the previous client's resources first so they can be freed
// as soon as nothing references them (see shared_ptr semantics).
outputProtocol.reset();
inputProtocol.reset();
outputTransport.reset();
inputTransport.reset();
client.reset();
// If the concurrent-client limit is reached, wait until the count drops
// below limit_. mon_ acts like a condition variable here.
{
Synchronized sync(mon_);
while (clients_ >= limit_) {
mon_.wait();
}
}
client = serverTransport_->accept(); // blocks until a client connects; run ./client in another terminal to trigger this step
inputTransport = inputTransportFactory_->getTransport(client);
outputTransport = outputTransportFactory_->getTransport(client);
if (!outputProtocolFactory_) {
inputProtocol = inputProtocolFactory_->getProtocol(inputTransport, outputTransport);
outputProtocol = inputProtocol;
} else {
inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
}
// Hand the new connection off for processing.
newlyConnectedClient(shared_ptr<TConnectedClient>(
new TConnectedClient(getProcessor(inputProtocol, outputProtocol, client),
inputProtocol,
outputProtocol,
eventHandler_,
client),
bind(&TServerFramework::disposeConnectedClient, this, std::placeholders::_1)));
} catch (TTransportException& ttx) {
releaseOneDescriptor("inputTransport", inputTransport);
releaseOneDescriptor("outputTransport", outputTransport);
releaseOneDescriptor("client", client);
if (ttx.getType() == TTransportException::TIMED_OUT) {
// Accept timeout - continue processing.
continue;
} else if (ttx.getType() == TTransportException::END_OF_FILE
|| ttx.getType() == TTransportException::INTERRUPTED) {
// Server was interrupted. This only happens when stopping.
break;
} else {
// All other transport exceptions are logged.
// State of connection is unknown. Done.
string errStr = string("TServerTransport died: ") + ttx.what();
GlobalOutput(errStr.c_str());
break;
}
}
}
releaseOneDescriptor("serverTransport", serverTransport_);
}
几个关键点:listen、accept、newlyConnectedClient,其中accept 会在客户端 transport->open() 时返回。
listen
// Creates, configures, binds and listens on the server socket.
// NOTE: "……" lines mark code elided from the original source.
void TServerSocket::listen() {
listening_ = true;
……
// Validate port number
if (port_ < 0 || port_ > 0xFFFF) {
throw TTransportException(TTransportException::BAD_ARGS, "Specified port is invalid");
}
const struct addrinfo *res;
int error;
char port[sizeof("65535")];
THRIFT_SNPRINTF(port, sizeof(port), "%d", port_);
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = PF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
// If address is not specified use wildcard address (NULL)
TGetAddrInfoWrapper info(address_.empty() ? nullptr : &address_[0], port, &hints);
error = info.init();
if (error) {
……
}
// Pick the ipv6 address first since ipv4 addresses can be mapped
// into ipv6 space.
for (res = info.res(); res; res = res->ai_next) {
if (res->ai_family == AF_INET6 || res->ai_next == nullptr)
break;
}
serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); // create the listening socket
// THRIFT_NO_SOCKET_CACHING maps to SO_REUSEADDR: allows rebinding the same
// port without waiting out the 2MSL TIME_WAIT delay (see TCP/IP Illustrated).
int one = 1;
if (-1 == setsockopt(serverSocket_,
SOL_SOCKET,
THRIFT_NO_SOCKET_CACHING,
cast_sockopt(&one),
sizeof(one))) {
……
}
……
// TCP_DEFER_ACCEPT: postpone waking accept() until the client actually sends
// data — to be verified with `gdb --args ./client`.
#ifdef TCP_DEFER_ACCEPT
if (path_.empty()) {
if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_DEFER_ACCEPT, &one, sizeof(one))) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN,
"Could not set TCP_DEFER_ACCEPT",
errno_copy);
}
}
#endif
// Disable lingering: close() returns immediately and any unsent buffered
// data is discarded.
struct linger ling = {0, 0};
if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER, cast_sockopt(&ling), sizeof(ling))) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_LINGER ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy);
}
// TCP_NODELAY disables the Nagle algorithm.
if (path_.empty()) {
// TCP Nodelay, speed over bandwidth
if (-1
== setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY, cast_sockopt(&one), sizeof(one))) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN,
"Could not set TCP_NODELAY",
errno_copy);
}
}
// Put the listening socket into non-blocking mode.
int flags = THRIFT_FCNTL(serverSocket_, THRIFT_F_GETFL, 0);
if (flags == -1) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() THRIFT_FCNTL() THRIFT_F_GETFL ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN,
"THRIFT_FCNTL() THRIFT_F_GETFL failed",
errno_copy);
}
if (-1 == THRIFT_FCNTL(serverSocket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() THRIFT_FCNTL() THRIFT_O_NONBLOCK ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN,
"THRIFT_FCNTL() THRIFT_F_SETFL THRIFT_O_NONBLOCK failed",
errno_copy);
}
int retries = 0;
int errno_copy = 0;
// Bind to the chosen address/port (with retries on failure).
if (!path_.empty()) {
……
} else {
do {
if (0 == ::bind(serverSocket_, res->ai_addr, static_cast<int>(res->ai_addrlen))) {
break;
}
errno_copy = THRIFT_GET_SOCKET_ERROR;
// use short circuit evaluation here to only sleep if we need to
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
……
// Start listening.
if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
close();
throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy);
}
}
这里删了一些无关代码,但还是很长, 其实主要是bind 和 listen。
accept
./src/thrift/transport/TServerTransport.h
// Public accept wrapper: delegates to acceptImpl() and enforces the
// contract that a successful accept never returns a null transport.
std::shared_ptr<TTransport> accept() {
  auto result = acceptImpl();
  if (result == nullptr) {
    throw TTransportException("accept() may not return NULL");
  }
  return result;
}
src/thrift/transport/TServerSocket.cpp
// (code trimmed for brevity — "……" marks elided lines)
// Waits (via poll) for the listening socket to become readable, accepts the
// pending connection, and wraps the new fd in a configured TSocket.
shared_ptr<TTransport> TServerSocket::acceptImpl() {
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
}
struct THRIFT_POLLFD fds[2];
while (true) {
std::memset(fds, 0, sizeof(fds));
fds[0].fd = serverSocket_;
fds[0].events = THRIFT_POLLIN;
// The interrupt socket lets another thread wake this poll() to stop the server.
if (interruptSockReader_ != THRIFT_INVALID_SOCKET) {
fds[1].fd = interruptSockReader_;
fds[1].events = THRIFT_POLLIN;
}
int ret = THRIFT_POLL(fds, 2, accTimeout_);
……
// Check for the actual server socket being ready
if (fds[0].revents & THRIFT_POLLIN) { // a client issued connect()
break;
}
} else {
GlobalOutput("TServerSocket::acceptImpl() THRIFT_POLL 0");
throw TTransportException(TTransportException::UNKNOWN);
}
}
struct sockaddr_storage clientAddress;
int size = sizeof(clientAddress);
// poll() already reported readiness above, so this accept() will not block.
THRIFT_SOCKET clientSocket = ::accept(serverSocket_, (struct sockaddr*)&clientAddress, (socklen_t*)&size);
// Reset the accepted socket to blocking mode (it inherits the listening
// socket's O_NONBLOCK flag on some platforms; here it is cleared).
int flags = THRIFT_FCNTL(clientSocket, THRIFT_F_GETFL, 0);
if (flags == -1) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
::THRIFT_CLOSESOCKET(clientSocket);
GlobalOutput.perror("TServerSocket::acceptImpl() THRIFT_FCNTL() THRIFT_F_GETFL ", errno_copy);
throw TTransportException(TTransportException::UNKNOWN,
"THRIFT_FCNTL(THRIFT_F_GETFL)",
errno_copy);
}
if (-1 == THRIFT_FCNTL(clientSocket, THRIFT_F_SETFL, flags & ~THRIFT_O_NONBLOCK)) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
::THRIFT_CLOSESOCKET(clientSocket);
GlobalOutput
.perror("TServerSocket::acceptImpl() THRIFT_FCNTL() THRIFT_F_SETFL ~THRIFT_O_NONBLOCK ",
errno_copy);
throw TTransportException(TTransportException::UNKNOWN,
"THRIFT_FCNTL(THRIFT_F_SETFL)",
errno_copy);
}
shared_ptr<TSocket> client = createSocket(clientSocket); // wrap the raw fd in a TSocket
if (sendTimeout_ > 0) {
client->setSendTimeout(sendTimeout_);
}
if (recvTimeout_ > 0) {
client->setRecvTimeout(recvTimeout_);
}
if (keepAlive_) {
client->setKeepAlive(keepAlive_);
}
client->setCachedAddress((sockaddr*)&clientAddress, size);
if (acceptCallback_)
acceptCallback_(clientSocket);
return client;
}
这里建立了和客户端的连接。
newlyConnectedClient
src/thrift/server/TServerFramework.cpp
// Accounts for a newly accepted client and delegates to the concrete
// server's onClientConnected() (e.g. TSimpleServer runs it inline).
void TServerFramework::newlyConnectedClient(const shared_ptr<TConnectedClient>& pClient) {
{
Synchronized sync(mon_);
++clients_; // bump the concurrent-client count under the monitor lock
hwm_ = (std::max)(hwm_, clients_); // track the high-water mark
}
onClientConnected(pClient);
}
onClientConnected
src/thrift/server/TSimpleServer.cpp
// TSimpleServer is single-threaded: the client is served inline on the
// accept thread, so serve() cannot accept again until run() returns.
void TSimpleServer::onClientConnected(const shared_ptr<TConnectedClient>& pClient) {
pClient->run();
}
run
src/thrift/server/TConnectedClient.cpp
// Per-connection request loop: repeatedly calls processor_->process() until
// the client disconnects, the processor signals completion, or an exception
// ends the session. cleanup() runs on every exit path shown here.
void TConnectedClient::run() {
if (eventHandler_) {
opaqueContext_ = eventHandler_->createContext(inputProtocol_, outputProtocol_);
}
for (bool done = false; !done;) {
if (eventHandler_) {
eventHandler_->processContext(opaqueContext_, client_);
}
try {
if (!processor_->process(inputProtocol_, outputProtocol_, opaqueContext_)) { // reads and dispatches one request
break;
}
} catch (const TTransportException& ttx) {
switch (ttx.getType()) {
case TTransportException::END_OF_FILE:
case TTransportException::INTERRUPTED:
case TTransportException::TIMED_OUT:
// Client disconnected or was interrupted or did not respond within the receive timeout.
// No logging needed. Done.
done = true;
break;
default: {
// All other transport exceptions are logged.
// State of connection is unknown. Done.
string errStr = string("TConnectedClient died: ") + ttx.what();
GlobalOutput(errStr.c_str());
done = true;
break;
}
}
} catch (const TException& tex) {
string errStr = string("TConnectedClient processing exception: ") + tex.what();
GlobalOutput(errStr.c_str());
// Disconnect from client, because we could not process the message.
done = true;
}
}
cleanup();
}
process
/usr/local/include/thrift/TDispatchProcessor.h
// Base processor: reads the message header off the wire and dispatches the
// named method to the generated subclass via dispatchCall().
// (Class body continues beyond this excerpt.)
class TDispatchProcessor : public TProcessor {
public:
virtual bool process(boost::shared_ptr<protocol::TProtocol> in,
boost::shared_ptr<protocol::TProtocol> out,
void* connectionContext) {
std::string fname;
protocol::TMessageType mtype;
int32_t seqid;
// Blocks here until the client sends a request (readMessageBegin).
in->readMessageBegin(fname, mtype, seqid);
if (mtype != protocol::T_CALL && mtype != protocol::T_ONEWAY) {
GlobalOutput.printf("received invalid message type %d from client", mtype);
return false;
}
return dispatchCall(in.get(), out.get(), fname, seqid, connectionContext); // dispatch by method name
}
这里 readMessageBegin 会阻塞直到客户端调用 rpc 方法 helloString 时返回。
dispatchCall
这里就到thrift 自动生成的代码 gen-cpp/Hello.cpp
// Generated dispatcher: looks the method name up in processMap_
// (method name -> member-function pointer) and invokes the handler.
// Unknown methods get a TApplicationException reply instead of an error.
bool HelloProcessor::dispatchCall(::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, const std::string& fname, int32_t seqid, void* callContext) {
ProcessMap::iterator pfn;
pfn = processMap_.find(fname);
if (pfn == processMap_.end()) {
// Method not found: drain the request payload, then reply with UNKNOWN_METHOD.
iprot->skip(::apache::thrift::protocol::T_STRUCT);
iprot->readMessageEnd();
iprot->getTransport()->readEnd();
::apache::thrift::TApplicationException x(::apache::thrift::TApplicationException::UNKNOWN_METHOD, "Invalid method name: '"+fname+"'");
oprot->writeMessageBegin(fname, ::apache::thrift::protocol::T_EXCEPTION, seqid);
x.write(oprot);
oprot->writeMessageEnd();
oprot->getTransport()->writeEnd();
oprot->getTransport()->flush();
return true;
}
// Invoke the matching process_* member function (e.g. process_helloString).
(this->*(pfn->second))(seqid, iprot, oprot, callContext);
return true;
}
这里的关键在: (this->*(pfn->second))(seqid, iprot, oprot, callContext); 这一句
fname 是客户端传来的rpc方法名,而server保存的是一个map,key是方法名,value是函数地址,所以看到的是pfn->second 的调用,详细可以看Hello.h 这个文件:
processMap_["helloString"] = &HelloProcessor::process_helloString;
process_helloString
// Generated per-method glue: deserializes the helloString args, calls the
// user handler (iface_->helloString), and serializes the result (or a
// TApplicationException) back to the client, flushing the transport.
void HelloProcessor::process_helloString(int32_t seqid, ::apache::thrift::protocol::TProtocol* iprot, ::apache::thrift::protocol::TProtocol* oprot, void* callContext)
{
void* ctx = NULL;
if (this->eventHandler_.get() != NULL) {
ctx = this->eventHandler_->getContext("Hello.helloString", callContext);
}
// Ensures the event-handler context is freed on every exit path (RAII).
::apache::thrift::TProcessorContextFreer freer(this->eventHandler_.get(), ctx, "Hello.helloString");
if (this->eventHandler_.get() != NULL) {
this->eventHandler_->preRead(ctx, "Hello.helloString");
}
// Deserialize the request arguments from the input protocol.
Hello_helloString_args args;
args.read(iprot);
iprot->readMessageEnd();
uint32_t bytes = iprot->getTransport()->readEnd();
if (this->eventHandler_.get() != NULL) {
this->eventHandler_->postRead(ctx, "Hello.helloString", bytes);
}
Hello_helloString_result result;
try {
// Local call into the user-supplied handler (HelloHandler::helloString).
iface_->helloString(result.success, args.sIn);
result.__isset.success = true;
} catch (const std::exception& e) {
if (this->eventHandler_.get() != NULL) {
this->eventHandler_->handlerError(ctx, "Hello.helloString");
}
// Handler threw: report it to the client as a TApplicationException.
::apache::thrift::TApplicationException x(e.what());
oprot->writeMessageBegin("helloString", ::apache::thrift::protocol::T_EXCEPTION, seqid);
x.write(oprot);
oprot->writeMessageEnd();
oprot->getTransport()->writeEnd();
oprot->getTransport()->flush();
return;
}
if (this->eventHandler_.get() != NULL) {
this->eventHandler_->preWrite(ctx, "Hello.helloString");
}
// Serialize the reply and flush it back to the client.
oprot->writeMessageBegin("helloString", ::apache::thrift::protocol::T_REPLY, seqid);
result.write(oprot);
oprot->writeMessageEnd();
bytes = oprot->getTransport()->writeEnd();
oprot->getTransport()->flush();
if (this->eventHandler_.get() != NULL) {
this->eventHandler_->postWrite(ctx, "Hello.helloString", bytes);
}
}
函数的本地调用,还有结果的回包。
这里涉及的细节还是挺多,但是用gdb调试还是很方便的。