1. 简单websocket 客户端实现
参考源码:
见:
1.1 核心源码解析
1.1.1 创建客户端对象并连接服务端
// Create the client for a ws:// URL; the second argument is the constructor's
// use_mask flag. The object is allocated with new, so the caller owns it.
CWebsocket* ws_cli = new CWebsocket("ws://localhost:20600/ar/9999", true);
/**
 * Builds a client from a "ws://host[:port][/path]" URL.
 *
 * @param url      WebSocket URL; it is stored in m_url and split into
 *                 m_host / m_port / m_path.
 * @param use_mask Whether outgoing frames are masked; stored in m_use_mask,
 *                 which the frame-building code reads.
 *
 * If the URL cannot be parsed, m_host stays empty so that a later
 * connect attempt fails in name resolution instead of using garbage.
 */
CWebsocket::CWebsocket(const string& url, bool use_mask) {
    m_url = url;
    m_use_mask = use_mask;

    // Start from well-defined values: sscanf leaves every field it does not
    // reach untouched, so a URL without a port or path would otherwise leave
    // m_port / m_path uninitialized.
    m_host[0] = '\0';
    m_path[0] = '\0';
    m_port = 80;

    // NOTE(review): %[^:/] and %s are unbounded; the sizes of m_host and
    // m_path are not visible here, so width limits should be added where
    // the buffers are declared.
    if (sscanf(url.c_str(), "ws://%[^:/]:%d/%s", m_host, &m_port, m_path) < 2) {
        // No explicit port: accept "ws://host[/path]" and keep port 80.
        sscanf(url.c_str(), "ws://%[^:/]/%s", m_host, m_path);
    }
}
// Resolve the host and open the TCP connection; ret carries the function's
// result (LIGHTWS_ERR_CONNECT on failure).
ret = ws_cli->connect_hostname();
/**
 * Resolves m_host:m_port and connects m_sock_fd to the first address that
 * accepts the connection.
 *
 * @return 0 on success, LIGHTWS_ERR_CONNECT if name resolution fails or no
 *         resolved address can be connected.
 *         NOTE(review): 0 is used as the success value because the original
 *         fell off the end of this int function (undefined behaviour);
 *         confirm it matches the project's success code.
 */
int CWebsocket::connect_hostname() {
    struct addrinfo hints;
    struct addrinfo* result;
    struct addrinfo* p;
    int ret = 0;
    std::string origin;  // not used in this part; the handshake step (section 1.1.2) reads it
    m_sock_fd = INVALID_SOCKET;
    char sport[16];

    // Ask for any address family (IPv4 or IPv6) with a stream socket.
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    snprintf(sport, 16, "%d", m_port);

    ret = getaddrinfo(m_host, sport, &hints, &result);
    if (ret != 0) {
        LIGHTWS_LOG("[ERR]Failed to getaddrinfo:%s\n", gai_strerror(ret));
        return LIGHTWS_ERR_CONNECT;
    }

    // Try each resolved address in turn and keep the first that connects.
    for (p = result; p != NULL; p = p->ai_next) {
        m_sock_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
        if (m_sock_fd == INVALID_SOCKET) {
            continue;
        }
        ret = connect(m_sock_fd, p->ai_addr, p->ai_addrlen);
        if (ret != SOCKET_ERROR) {
            break;
        }
        // This candidate failed: release the socket before trying the next.
        closesocket(m_sock_fd);
        m_sock_fd = INVALID_SOCKET;
    }
    freeaddrinfo(result);

    if (m_sock_fd == INVALID_SOCKET) {
        LIGHTWS_LOG("[ERR]Unable to connect to %s:%d\n", m_host, m_port);
        return LIGHTWS_ERR_CONNECT;
    }
    return 0;
}
1.1.2 和服务端进行websocket握手
// Websocket handshake: send the HTTP Upgrade request header by header, then
// require a "101" status line and skip the response headers up to the blank
// line. Any send/recv failure returns LIGTHWS_ERR_HANDSHAKE.
{
    char line[1024];
    int status;

    // Sends the NUL-terminated contents of `line`; false if send() failed.
    auto send_line = [&]() -> bool {
        return ::send(m_sock_fd, line, strlen(line), 0) != SOCKET_ERROR;
    };

    // Reads one line byte by byte into `line` until it ends in "\r\n" or the
    // buffer is full (1023 bytes), and NUL-terminates it.
    // Returns the number of bytes read, or -1 if the peer closed the
    // connection or recv() failed.
    // The terminator test is "not (CR followed by LF)"; the original
    // `!= '\r' && != '\n'` form also stopped on a lone CR or LF, and only
    // recv() == 0 was treated as failure.
    auto read_line = [&]() -> int {
        int len = 0;
        while (len < 2 || (len < 1023 && !(line[len-2] == '\r' && line[len-1] == '\n'))) {
            if (recv(m_sock_fd, line + len, 1, 0) <= 0) {
                return -1;
            }
            ++len;
        }
        line[len] = 0;
        return len;
    };

    bool sent_ok = true;
    snprintf(line, 1024, "GET /%s HTTP/1.1\r\n", m_path);
    sent_ok = sent_ok && send_line();
    // The default port 80 is omitted from the Host header.
    if (m_port == 80) {
        snprintf(line, 1024, "Host: %s\r\n", m_host);
    }
    else {
        snprintf(line, 1024, "Host: %s:%d\r\n", m_host, m_port);
    }
    sent_ok = sent_ok && send_line();
    snprintf(line, 1024, "Upgrade: websocket\r\n");
    sent_ok = sent_ok && send_line();
    snprintf(line, 1024, "Connection: Upgrade\r\n");
    sent_ok = sent_ok && send_line();
    if (!origin.empty()) {
        snprintf(line, 1024, "Origin: %s\r\n", origin.c_str());
        sent_ok = sent_ok && send_line();
    }
    // NOTE(review): the key is a fixed constant; RFC 6455 expects a fresh
    // random nonce per connection.
    snprintf(line, 1024, "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n");
    sent_ok = sent_ok && send_line();
    snprintf(line, 1024, "Sec-WebSocket-Version: 13\r\n");
    sent_ok = sent_ok && send_line();
    snprintf(line, 1024, "\r\n");
    sent_ok = sent_ok && send_line();
    if (!sent_ok) {
        return LIGTHWS_ERR_HANDSHAKE;
    }

    // Status line.
    int len = read_line();
    if (len < 0) {
        return LIGTHWS_ERR_HANDSHAKE;
    }
    if (len == 1023) {
        LIGHTWS_LOG("ERROR: Got invalid status line connecting to: %s\n", m_url.c_str());
        return LIGTHWS_ERR_HANDSHAKE;
    }
    if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101) {
        LIGHTWS_LOG("ERROR: Got bad status connecting to %s: %s", m_url.c_str(), line);
        return LIGTHWS_ERR_HANDSHAKE;
    }

    // TODO: verify response headers (e.g. Sec-WebSocket-Accept).
    // For now they are read and discarded until the empty line that ends
    // the HTTP response head.
    while (true) {
        len = read_line();
        if (len < 0) {
            return LIGTHWS_ERR_HANDSHAKE;
        }
        if (line[0] == '\r' && line[1] == '\n') {
            break;
        }
    }
}
1.1.3 构造帧头并对要发送的数据进行mask处理
const uint8_t masking_key[4] = { 0x12, 0x34, 0x56, 0x78 };
// TODO: consider acquiring a lock on txbuf...
if (m_ws_state == CLOSING || m_ws_state == CLOSED) {
return LIGHTWS_ERR_CLOSED;
}
std::vector<uint8_t> header;
header.assign(2 + (message_size >= 126 ? 2 : 0) + (message_size >= 65536 ? 6 : 0) + (m_use_mask ? 4 : 0), 0);
header[0] = 0x80 | type;
if (message_size < 126) {
header[1] = (message_size & 0xff) | (m_use_mask ? 0x80 : 0);
if (m_use_mask) {
header[2] = masking_key[0];
header[3] = masking_key[1];
header[4] = masking_key[2];
header[5] = masking_key[3];
}
}
else if (message_size < 65536) {
header[1] = 126 | (m_use_mask ? 0x80 : 0);
header[2] = (message_size >> 8) & 0xff;
header[3] = (message_size >> 0) & 0xff;
if (m_use_mask) {
header[4] = masking_key[0];
header[5] = masking_key[1];
header[6] = masking_key[2];
header[7] = masking_key[3];
}
}
else { // TODO: run coverage testing here
header[1] = 127 | (m_use_mask ? 0x80 : 0);
header[2] = (message_size >> 56) & 0xff;
header[3] = (message_size >> 48) & 0xff;
header[4] = (message_size >> 40) & 0xff;
header[5] = (message_size >> 32) & 0xff;
header[6] = (message_size >> 24) & 0xff;
header[7] = (message_size >> 16) & 0xff;
header[8] = (message_size >> 8) & 0xff;
header[9] = (message_size >> 0) & 0xff;
if (m_use_mask) {
header[10] = masking_key[0];
header[11] = masking_key[1];
header[12] = masking_key[2];
header[13] = masking_key[3];
}
}
// N.B. - txbuf will keep growing until it can be transmitted over the socket:
m_txbuf.insert(m_txbuf.end(), header.begin(), header.end());
m_txbuf.insert(m_txbuf.end(), message_begin, message_end);
if (m_use_mask) {
size_t message_offset = m_txbuf.size() - message_size;
for (size_t i = 0; i != message_size; ++i) {
m_txbuf[message_offset + i] ^= masking_key[i&0x3];
}
}
1.1.4 通过socket发送mask处理后的数据
// Send the request: write m_txbuf to the socket until it is empty, the
// socket reports it would block, or the connection fails.
while (!m_txbuf.empty()) {
    const int sent = ::send(m_sock_fd, reinterpret_cast<char*>(&m_txbuf[0]), m_txbuf.size(), 0);
    if (sent > 0) {
        // Drop the bytes that were actually written and try again.
        m_txbuf.erase(m_txbuf.begin(), m_txbuf.begin() + sent);
        continue;
    }
    if (sent < 0 && (socketerrno == SOCKET_EWOULDBLOCK || socketerrno == SOCKET_EAGAIN_EINPROGRESS)) {
        // Nothing more can be written right now; the rest stays queued.
        break;
    }
    // sent == 0 or a real error: give up on the connection.
    closesocket(m_sock_fd);
    m_ws_state = CLOSED;
    fputs(sent < 0 ? "Connection error!\n" : "Connection closed!\n", stderr);
    break;
}
1.1.5 接收response
// Receive the server's response: append incoming bytes to m_rxbuf in chunks
// of up to 1500 bytes until the socket would block or the connection ends.
while (true) {
    const int old_size = m_rxbuf.size();
    // Grow first so recv() can write directly after the existing bytes.
    m_rxbuf.resize(old_size + 1500);
    const ssize_t got = recv(m_sock_fd, reinterpret_cast<char*>(&m_rxbuf[0]) + old_size, 1500, 0);
    if (got > 0) {
        // Keep only the bytes actually received.
        m_rxbuf.resize(old_size + got);
        continue;
    }
    const bool would_block =
        got < 0 && (socketerrno == SOCKET_EWOULDBLOCK || socketerrno == SOCKET_EAGAIN_EINPROGRESS);
    // Nothing was read: undo the speculative growth.
    m_rxbuf.resize(old_size);
    if (would_block) {
        break;
    }
    // got == 0 (peer closed) or a real error.
    closesocket(m_sock_fd);
    m_ws_state = CLOSED;
    fputs(got < 0 ? "Connection error!\n" : "Connection closed!\n", stderr);
    break;
}
1.1.6 解析response数据
// Parses and handles the WebSocket frames buffered in m_rxbuf, one frame per
// iteration. The loop has no break: it is left only through the
// `return LIGHTWS_ERR_DISPATCH` statements, i.e. when the buffer does not yet
// hold a complete frame or when a frame carries an invalid length.
while (true) {
ws_header_type_t ws;
if (m_rxbuf.size() < 2) {
return LIGHTWS_ERR_DISPATCH; /* Need at least 2 */
}
const uint8_t * data = (uint8_t *) &m_rxbuf[0]; // peek, but don't consume
// Byte 0: FIN flag (top bit) and opcode (low 4 bits).
ws.fin = (data[0] & 0x80) == 0x80;
ws.opcode = (ws_header_type_t::EOpcodeType) (data[0] & 0x0f);
// Byte 1: MASK flag (top bit) and 7-bit length code N0.
ws.mask = (data[1] & 0x80) == 0x80;
ws.N0 = (data[1] & 0x7f);
// Header = 2 fixed bytes + optional extended length (2 or 8 bytes) + optional 4-byte masking key.
ws.header_size = 2 + (ws.N0 == 126? 2 : 0) + (ws.N0 == 127? 8 : 0) + (ws.mask? 4 : 0);
if (m_rxbuf.size() < ws.header_size) {
return LIGHTWS_ERR_DISPATCH; /* Need: ws.header_size - rxbuf.size() */
}
// i ends up as the offset of the masking key (the first byte after the length field).
int i = 0;
if (ws.N0 < 126) {
// N0 is the payload length itself.
ws.N = ws.N0;
i = 2;
}
else if (ws.N0 == 126) {
// 16-bit big-endian payload length in bytes 2..3.
ws.N = 0;
ws.N |= ((uint64_t) data[2]) << 8;
ws.N |= ((uint64_t) data[3]) << 0;
i = 4;
}
else if (ws.N0 == 127) {
// 64-bit big-endian payload length in bytes 2..9.
ws.N = 0;
ws.N |= ((uint64_t) data[2]) << 56;
ws.N |= ((uint64_t) data[3]) << 48;
ws.N |= ((uint64_t) data[4]) << 40;
ws.N |= ((uint64_t) data[5]) << 32;
ws.N |= ((uint64_t) data[6]) << 24;
ws.N |= ((uint64_t) data[7]) << 16;
ws.N |= ((uint64_t) data[8]) << 8;
ws.N |= ((uint64_t) data[9]) << 0;
i = 10;
if (ws.N & 0x8000000000000000ull) {
// https://tools.ietf.org/html/rfc6455 writes the "the most
// significant bit MUST be 0."
//
// We can't drop the frame, because (1) we don't
// know how much data to skip over to find the next header,
// and (2) this would be an impractically long length, even
// if it were valid. So just close() and return immediately
// for now.
m_is_rx_bad = true;
LIGHTWS_LOG("ERROR: Frame has invalid frame length. Closing.\n");
close();
return LIGHTWS_ERR_DISPATCH;
}
}
// Copy the masking key if present; otherwise use an all-zero key.
if (ws.mask) {
ws.masking_key[0] = ((uint8_t) data[i+0]) << 0;
ws.masking_key[1] = ((uint8_t) data[i+1]) << 0;
ws.masking_key[2] = ((uint8_t) data[i+2]) << 0;
ws.masking_key[3] = ((uint8_t) data[i+3]) << 0;
}
else {
ws.masking_key[0] = 0;
ws.masking_key[1] = 0;
ws.masking_key[2] = 0;
ws.masking_key[3] = 0;
}
// Note: The checks above should hopefully ensure this addition
// cannot overflow:
if (m_rxbuf.size() < ws.header_size+ws.N) {
return LIGHTWS_ERR_DISPATCH; /* Need: ws.header_size+ws.N - rxbuf.size() */
}
// We got a whole message, now do something with it:
if (ws.opcode == ws_header_type_t::EOpcodeType::TEXT_FRAME
|| ws.opcode == ws_header_type_t::EOpcodeType::BINARY_FRAME
|| ws.opcode == ws_header_type_t::EOpcodeType::CONTINUATION) {
// Data frame: unmask the payload in place, then append it to the message being assembled.
if (ws.mask) {
for (size_t i = 0; i != ws.N; ++i) {
m_rxbuf[i+ws.header_size] ^= ws.masking_key[i&0x3];
}
}
m_recved_data.insert(m_recved_data.end(), m_rxbuf.begin()+ws.header_size,
m_rxbuf.begin()+ws.header_size+(size_t)ws.N);// just feed
if (ws.fin) {
// Final fragment: the message is complete; print it and reset the assembly buffer.
//callable((const std::vector<uint8_t>) m_recved_data);
std::string stringMessage(m_recved_data.begin(), m_recved_data.end());
printf(">>> %s\n", stringMessage.c_str());
m_recved_data.erase(m_recved_data.begin(), m_recved_data.end());
std::vector<uint8_t> ().swap(m_recved_data);// free memory
}
}
else if (ws.opcode == ws_header_type_t::EOpcodeType::PING) {
// PING: unmask the payload and send it back in a PONG frame.
if (ws.mask) {
for (size_t i = 0; i != ws.N; ++i) {
m_rxbuf[i+ws.header_size] ^= ws.masking_key[i&0x3];
}
}
// Note: this local std::string shadows the `data` pointer declared above.
std::string data(m_rxbuf.begin()+ws.header_size, m_rxbuf.begin()+ws.header_size+(size_t)ws.N);
send_data(ws_header_type_t::PONG, data.size(), data.begin(), data.end());
}
// PONG frames are ignored.
else if (ws.opcode == ws_header_type_t::EOpcodeType::PONG) { }
else if (ws.opcode == ws_header_type_t::EOpcodeType::CLOSE) {
close();
}
else {
// Any other opcode is treated as a protocol error.
fprintf(stderr, "ERROR: Got unexpected WebSocket message.\n");
close();
}
// Consume the frame (header + payload) that was just handled.
m_rxbuf.erase(m_rxbuf.begin(), m_rxbuf.begin() + ws.header_size+(size_t)ws.N);
}
2. 简单websocket 服务端实现
完整源码:
2.1 核心源码解析
2.1.1 TCP服务端端口监听
// Create the listening TCP socket.
listenfd_ = socket(AF_INET, SOCK_STREAM, 0);
// Bind it to PORT on every local interface (INADDR_ANY) and start listening
// with a backlog of 5.
// NOTE(review): the results of socket(), bind() and listen() are not checked
// in this excerpt; real code should test each for -1 before continuing.
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(sockaddr_in));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
server_addr.sin_port = htons(PORT);
bind(listenfd_, (struct sockaddr *)(&server_addr), sizeof(server_addr));
listen(listenfd_, 5);
2.1.2 epoll创建并注册事件
/*
typedef std::map<int, Websocket_Handler *> WEB_SOCKET_HANDLER_MAP;
WEB_SOCKET_HANDLER_MAP websocket_handler_map_;
websocket_handler_map_ is:
a map of Websocket_Handler objects indexed by socket descriptor,
so that the server has one dedicated handler object for every socket connection.
*/
// Create the epoll instance.
epollfd_ = epoll_create(MAXEVENTSSIZE);
/* Register the EPOLLIN event of the listening socket with the epoll instance. */
ctl_event(listenfd_, true);
void Network_Interface::ctl_event(int fd, bool flag){
struct epoll_event ev;
ev.data.fd = fd;
ev.events = flag ? EPOLLIN : 0;
epoll_ctl(epollfd_, flag ? EPOLL_CTL_ADD : EPOLL_CTL_DEL, fd, &ev);
if(flag){
set_noblock(fd);
websocket_handler_map_[fd] = new Websocket_Handler(fd);
if(fd != listenfd_) DEBUG_LOG("fd: %d 加入epoll循环", fd);
} else{
close(fd);
delete websocket_handler_map_[fd];
websocket_handler_map_.erase(fd);
DEBUG_LOG("fd: %d 退出epoll循环", fd);
}
}
2.1.3 epoll主循环
#define TIMEWAIT 100
int Network_Interface::epoll_loop(){
struct sockaddr_in client_addr;
struct epoll_event events[MAXEVENTSSIZE];
while(true){
nfds = epoll_wait(epollfd_, events, MAXEVENTSSIZE, TIMEWAIT);
for(int i = 0; i < nfds; i++){
// 接受客户端的连接请求,
// 并注册这个连接socket句柄的EPOLLIN事件监听到epoll
if(events[i].data.fd == listenfd_){
fd = accept(listenfd_, (struct sockaddr *)&client_addr, &clilen);
ctl_event(fd, true);
}
// 接收到了客户端发送的数据,
// 将数据读取到句柄的buffer, 然后进行处理
else if(events[i].events & EPOLLIN){
Websocket_Handler *handler = websocket_handler_map_[fd];
bufflen = read(fd, handler->getbuff(), BUFFLEN)
handler->process();
}
}
}
return 0;
}
2.1.4 websocket的 request处理
/**
 * Handles the bytes currently held in buff_.
 *
 * While the connection is still WEBSOCKET_UNCONNECT the data is treated as
 * the opening handshake and the result of handshark() is returned.
 * Afterwards the buffer is parsed as a WebSocket frame, printed, and cleared.
 *
 * @return the handshake result for the first request, otherwise 0.
 */
int Websocket_Handler::process(){
    if(status_ != WEBSOCKET_UNCONNECT){
        // Already handshaken: extract the frame and print it.
        request_->fetch_websocket_info(buff_);
        request_->print();
        memset(buff_, 0, sizeof(buff_));
        return 0;
    }
    // First request on this connection: perform the WebSocket handshake.
    return handshark();
}
2.1.4.1 websocket的握手
服务端解析客户端发来的握手请求(HTTP Upgrade 请求),取出其中的 Sec-WebSocket-Key,
在其后拼接固定的 MAGIC_KEY,计算 SHA1 摘要,再对摘要进行 Base64 编码,
把结果放入响应的 Sec-WebSocket-Accept 头返回给客户端,从而完成握手;
/**
 * Performs the server side of the WebSocket opening handshake using the
 * HTTP request currently held in buff_.
 *
 * @return the result of send_data() for the handshake response, or -1 if
 *         the request could not be parsed. In that case nothing is sent and
 *         status_ is left unchanged, so the next data is again treated as a
 *         handshake attempt.
 */
int Websocket_Handler::handshark(){
    char request[1024] = {};
    // Split the HTTP head into header_map_. A malformed request must not be
    // answered with "101 Switching Protocols".
    if (fetch_http_info() != 0) {
        memset(buff_, 0, sizeof(buff_));
        return -1;
    }
    status_ = WEBSOCKET_HANDSHARKED;
    parse_str(request); // build the response, including the accept key
    memset(buff_, 0, sizeof(buff_));
    // Send the handshake response to the client to complete the handshake.
    return send_data(request);
}
/**
 * Parses the HTTP request in buff_ line by line and stores every
 * "Name: value" header in header_map_.
 *
 * Lines must end in "\r\n". Header lines that do not, or that contain no
 * ": " separator, are skipped. Parsing stops at the empty line that ends
 * the HTTP head.
 *
 * @return 0 on success, -1 if the request line is empty or not terminated
 *         by '\r'.
 */
int Websocket_Handler::fetch_http_info(){
    std::istringstream s(buff_);
    std::string request;
    std::getline(s, request);
    // Check for emptiness first: indexing size()-1 on an empty string is
    // out of range.
    if (request.empty() || request[request.size()-1] != '\r') {
        return -1;
    }

    std::string header;
    while (std::getline(s, header) && header != "\r") {
        if (header.empty() || header[header.size()-1] != '\r') {
            continue; // not a CRLF-terminated line
        }
        header.erase(header.end()-1); // drop the trailing '\r'

        const std::string::size_type sep = header.find(": ", 0);
        if (sep != std::string::npos) {
            header_map_[header.substr(0, sep)] = header.substr(sep + 2);
        }
    }
    return 0;
}
/**
 * Appends the "101 Switching Protocols" handshake response to `request`.
 *
 * The Sec-WebSocket-Accept value is computed as
 * base64(SHA1(client Sec-WebSocket-Key + MAGIC_KEY)).
 *
 * @param request NUL-terminated buffer that the response is appended to
 *                with strcat; the caller must provide enough room.
 */
void Websocket_Handler::parse_str(char *request){
    // Key sent by the client followed by the fixed magic value.
    std::string accept_key = header_map_["Sec-WebSocket-Key"];
    accept_key += MAGIC_KEY;

    SHA1 sha;
    unsigned int message_digest[5];
    sha.Reset();
    sha << accept_key.c_str();
    sha.Result(message_digest);
    // Convert each digest word to network byte order before encoding it as
    // raw bytes.
    const int word_count = sizeof(message_digest) / sizeof(message_digest[0]);
    for (int w = 0; w < word_count; ++w) {
        message_digest[w] = htonl(message_digest[w]);
    }
    accept_key = base64_encode(reinterpret_cast<const unsigned char*>(message_digest),20);

    // Assemble the whole response, then append it in one step.
    std::string response;
    response += "HTTP/1.1 101 Switching Protocols\r\n";
    response += "Connection: upgrade\r\n";
    response += "Sec-WebSocket-Accept: ";
    response += accept_key;
    response += "\r\n";
    response += "Upgrade: websocket\r\n\r\n";
    strcat(request, response.c_str());
}
/**
 * Writes the NUL-terminated string `buff` to the connection's socket.
 *
 * @param buff text to send; the terminating NUL is not sent.
 * @return the value returned by write().
 */
int Websocket_Handler::send_data(char *buff){
    const size_t length = strlen(buff);
    return write(fd_, buff, length);
}
2.1.4.2 非握手的数据处理
依次提取帧头各字段(FIN、opcode、MASK、payload length、masking key)以及数据body(payload);
/**
 * Extracts one WebSocket frame from `msg`: FIN flag, opcode, MASK flag,
 * payload length, masking key and finally the payload.
 *
 * @param msg raw frame bytes.
 * @return the result of fetch_payload().
 */
int Websocket_Request::fetch_websocket_info(char *msg){
    // Read position shared by all field extractors.
    // NOTE(review): presumably each helper advances it — confirm against their definitions.
    int cursor = 0;
    fetch_fin(msg, cursor);
    fetch_opcode(msg, cursor);
    fetch_mask(msg, cursor);
    fetch_payload_length(msg, cursor);
    fetch_masking_key(msg, cursor);
    const int payload_result = fetch_payload(msg, cursor);
    return payload_result;
}
2.1.5 websocket的response处理
没有实现;