0x01 代码
import base64
import hashlib
import os
import socket
import struct

try:
    import thread  # Python 2
except ImportError:
    import _thread as thread  # Python 3
class WebsocketServerBase(object):
    """Base class for a minimal, thread-per-connection WebSocket server.

    Subclasses override the ``on_*`` callbacks (chiefly ``on_message``) and
    call ``ws_service(ip, port)`` to start serving. ``recv_data`` and
    ``send_data`` handle the frame layout below (RFC 6455, section 5.2):

      0                   1                   2                   3
      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
     +-+-+-+-+-------+-+-------------+-------------------------------+
     |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
     |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
     |N|V|V|V|       |S|             |   (if payload len==126/127)   |
     | |1|2|3|       |K|             |                               |
     +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
     |     Extended payload length continued, if payload len == 127  |
     + - - - - - - - - - - - - - - - +-------------------------------+
     |                               |Masking-key, if MASK set to 1  |
     +-------------------------------+-------------------------------+
     | Masking-key (continued)       |          Payload Data         |
     +-------------------------------- - - - - - - - - - - - - - - - +
     :                     Payload Data continued ...                :
     + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
     |                     Payload Data continued ...                |
     +---------------------------------------------------------------+

    Attributes:
        gm_api_map: class-wide table of accepted request paths (API
            endpoints). It is shared by all instances and extended via
            register_api.
        timeout: per-connection socket timeout in seconds. It applies to the
            handshake and to every receive.
        max_payload_size: largest frame payload, in bytes, accepted from a
            client. Larger frames end the connection.
    """
    gm_api_map = {
        "/": True
    }
    timeout = 30
    # A client may announce up to 2**63 payload bytes; cap it so one frame
    # cannot make the server buffer unbounded memory.
    max_payload_size = 16 * 1024 * 1024

    def __init__(self):
        """Initialise the server; no per-instance state is required."""
        pass

    # ------------------------------------------------------------------ #
    # Low-level framing helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _recv_exact(conn, size):
        """Read exactly ``size`` bytes from ``conn``.

        Returns the bytes read, or None if the peer closed the connection
        before ``size`` bytes arrived.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = conn.recv(min(remaining, 65536))
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @staticmethod
    def _unmask(payload, mask):
        """XOR ``payload`` with the 4-byte client masking key ``mask``."""
        key = bytearray(mask)
        return bytes(bytearray(
            octet ^ key[index & 3] for index, octet in enumerate(bytearray(payload))))

    @staticmethod
    def _build_frame(opcode, payload):
        """Build an unmasked, final (FIN=1) server frame around ``payload`` bytes."""
        first = 0x80 | opcode
        length = len(payload)
        if length < 126:
            header = struct.pack('!BB', first, length)
        elif length <= 0xFFFF:
            header = struct.pack('!BBH', first, 126, length)
        else:
            header = struct.pack('!BBQ', first, 127, length)
        return header + payload

    def _read_frame(self, conn):
        """Read one complete frame from ``conn``.

        Returns:
            A ``(fin, opcode, payload)`` tuple with the payload already
            unmasked. Returns None when the connection ended mid-frame or
            the payload exceeds ``max_payload_size``.
        """
        head = self._recv_exact(conn, 2)
        if head is None:
            return None
        b0, b1 = bytearray(head)
        fin = bool(b0 & 0x80)
        opcode = b0 & 0x0F
        masked = bool(b1 & 0x80)
        length = b1 & 0x7F
        if length == 126:
            ext = self._recv_exact(conn, 2)
            if ext is None:
                return None
            length = struct.unpack('>H', ext)[0]
        elif length == 127:
            ext = self._recv_exact(conn, 8)
            if ext is None:
                return None
            length = struct.unpack('>Q', ext)[0]
        if length > self.max_payload_size:
            self.log("ERROR", "frame payload too large: %s bytes" % length)
            return None
        mask = b''
        if masked:
            mask = self._recv_exact(conn, 4)
            if mask is None:
                return None
        payload = self._recv_exact(conn, length)
        if payload is None:
            return None
        if masked:
            payload = self._unmask(payload, mask)
        return fin, opcode, payload

    # ------------------------------------------------------------------ #
    # Public send/receive API
    # ------------------------------------------------------------------ #
    def recv_data(self, conn):
        """Receive the next data message from a client and return it as text.

        The method reads exactly as many bytes as each frame declares, so
        large and back-to-back messages are handled. It reassembles
        fragmented messages, answers pings with a pong, and ignores pongs.

        Args:
            conn: the connected client socket.

        Returns:
            The message payload decoded as UTF-8; invalid bytes are replaced
            with U+FFFD. Returns '' when the peer closed the connection, sent
            a close frame, timed out, or an error occurred. An empty message
            also yields ''. Callers treat '' as "end of session".
        """
        fragments = []
        try:
            while True:
                frame = self._read_frame(conn)
                if frame is None:
                    return str()
                fin, opcode, payload = frame
                if opcode == 0x8:
                    # Close frame: the client is ending the session.
                    return str()
                if opcode == 0x9:
                    # Ping: reply with a pong carrying the same payload.
                    conn.sendall(self._build_frame(0xA, payload))
                    continue
                if opcode == 0xA:
                    # Unsolicited pong: nothing to do.
                    continue
                fragments.append(payload)
                if fin:
                    break
        except socket.timeout:
            # Idle for longer than the socket timeout: treat as a disconnect.
            return str()
        except Exception as e:
            self.log("WARN", "recv_data failed: %s" % e)
            return str()
        return b''.join(fragments).decode('utf-8', 'replace')

    def send_data(self, conn, data):
        """Send ``data`` to the client as a single text frame.

        Args:
            conn: the connected client socket.
            data: the payload. Bytes are sent as-is; text is UTF-8 encoded;
                any other object is converted with ``str()`` first.

        Returns:
            False if ``data`` is falsy and nothing was sent; True otherwise.

        Raises:
            socket.error: if the socket is closed or the write fails.
        """
        if not data:
            return False
        if isinstance(data, bytes):
            payload = data
        else:
            if not isinstance(data, type(u'')):
                data = str(data)
            payload = data.encode('utf-8')
        # sendall, not send: send may write only part of a large frame.
        conn.sendall(self._build_frame(0x1, payload))
        return True

    # ------------------------------------------------------------------ #
    # Handshake
    # ------------------------------------------------------------------ #
    @staticmethod
    def _read_http_request(conn, limit=65536):
        """Read until the blank line that ends an HTTP request header.

        Returns the raw bytes read. Returns None if the peer closed the
        connection first or the header grew beyond ``limit`` bytes.
        """
        buf = b''
        while b'\r\n\r\n' not in buf:
            if len(buf) > limit:
                return None
            chunk = conn.recv(4096)
            if not chunk:
                return None
            buf += chunk
        return buf

    def handshake(self, conn, address, thread_name):
        """Perform the HTTP Upgrade handshake for a newly accepted client.

        Args:
            conn: the accepted client socket.
            address: client address tuple; [0] is the IP, [1] the port.
            thread_name: name of the handling thread, used in log messages.

        Returns:
            On success, the request headers plus the 'Api' (request path),
            'HttpVersion' and 'Address' ("ip:port") entries. On failure, an
            empty dict. When the request is not a WebSocket upgrade, or the
            path is not registered in gm_api_map, the socket is closed here.
        """
        request = self._read_http_request(conn)
        if not request:
            return dict()
        header_text = request.split(b'\r\n\r\n', 1)[0].decode('latin-1')
        lines = header_text.split('\r\n')
        request_line = lines[0].split(' ')
        headers = {}
        for line in lines[1:]:
            key, sep, value = line.partition(':')
            key = key.strip()
            if sep and key:
                headers[key] = value.strip()
        # HTTP header names are case-insensitive: look them up by lowercase
        # name while keeping the client's spelling in the returned dict.
        lookup = dict((name.lower(), value) for name, value in headers.items())
        sec_key = lookup.get('sec-websocket-key')
        if len(request_line) < 3 or request_line[0] != 'GET' or not sec_key:
            self.log("INFO", '%s : This socket is not websocket, client close.' % thread_name)
            self.close(conn)
            return dict()
        # Set after parsing so request headers cannot override these entries.
        headers['Api'] = request_line[1]
        headers['HttpVersion'] = request_line[2]

        MAGIC_STRING = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        HANDSHAKE_STRING = "{0} 101 Switching Protocols\r\n" \
                           "Upgrade:websocket\r\n" \
                           "Connection: Upgrade\r\n" \
                           "Sec-WebSocket-Accept: {1}\r\n" \
                           "WebSocket-Origin: {2}\r\n" \
                           "WebSocket-Location: ws://{3}\r\n\r\n"
        digest = hashlib.sha1((sec_key + MAGIC_STRING).encode('latin-1')).digest()
        res_key = base64.b64encode(digest).decode('ascii')
        str_handshake = HANDSHAKE_STRING.format(
            headers['HttpVersion'],
            res_key,
            lookup.get('origin', ''),
            lookup.get('host', '') + headers['Api'],
        )
        conn.sendall(str_handshake.encode('latin-1'))

        if not self.gm_api_map.get(headers['Api']):
            self.log("ERROR", headers['Api'] + ": RESOURCE NOT FOUND!")
            self.send_data(conn, headers['Api'] + ": RESOURCE NOT FOUND!")
            self.close(conn)
            return dict()
        headers['Address'] = "{0}:{1}".format(address[0], address[1])
        self.log("INFO", headers)
        self.log("INFO", '%s : Socket handshaken with %s:%s success' % (thread_name, address[0], address[1]))
        return headers

    def close(self, conn):
        """Close the client socket. Closing an already-closed socket is harmless."""
        conn.close()

    # ------------------------------------------------------------------ #
    # Connection driver
    # ------------------------------------------------------------------ #
    def _message_loop(self, conn, conn_headers):
        """Pass incoming messages to on_message until the session ends."""
        while True:
            try:
                clientdata = self.recv_data(conn)
                if not clientdata:
                    # Peer closed, timed out or sent a close frame.
                    self.send_data(conn, 'close connect')
                    return
                if not self.on_message(conn_headers, conn, clientdata):
                    return
            except socket.error as e:
                # The socket is unusable (reset, broken pipe, closed by a
                # handler). Retrying would spin forever, so stop here.
                self.on_error(e)
                return
            except Exception as e:
                # A failing handler should not end the session; report it and
                # keep reading.
                self.on_error(e)

    def request_handle(self, conn, address, thread_name):
        """Serve one client: handshake, then the message loop, then close.

        Args:
            conn: the accepted client socket.
            address: client address tuple; [0] is the IP, [1] the port.
            thread_name: name of the handling thread, used in log messages.
        """
        # Set the timeout before the handshake so a silent client cannot hold
        # this thread indefinitely.
        conn.settimeout(self.timeout)
        try:
            conn_headers = self.handshake(conn, address, thread_name)
        except Exception as e:
            self.on_error(e)
            conn_headers = dict()
        if not conn_headers:
            self.log("ERROR", "get the connect fault, close!")
            self.close(conn)
            return
        try:
            self._message_loop(conn, conn_headers)
        finally:
            self.close(conn)
        self.log("INFO", '%s : Socket close with %s:%s' % (thread_name, address[0], address[1]))
        self.on_close()

    def set_timeout(self, timeout):
        """Set the receive timeout in seconds; idle clients are disconnected after it."""
        self.timeout = timeout

    def ws_service(self, ip, port):
        """Listen on (ip, port) and serve each client on its own thread.

        The method runs until accept() raises, and always closes the
        listening socket on exit.

        Args:
            ip: local address to bind.
            port: local port to bind.
        """
        index = 1
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((ip, port))
            sock.listen(100)
            self.log("INFO", 'Websocket server start, wait for connect!')
            while True:
                try:
                    connection, address = sock.accept()
                    thread_name = 'thread_%s' % index
                    self.log("INFO", '%s : Connection from %s:%s' % (thread_name, address[0], address[1]))
                    thread.start_new_thread(self.request_handle, (connection, address, thread_name,))
                    self.on_open(connection)
                    index += 1
                except Exception as e:
                    self.log("ERROR", e)
                    break
        finally:
            sock.close()

    def register_api(self, api):
        """Accept connections whose request path equals ``api``.

        Args:
            api: path string of the new endpoint, for example "/getData".
                It is added to the class-wide gm_api_map.
        """
        self.gm_api_map.update({api: True})

    # ------------------------------------------------------------------ #
    # Callbacks for subclasses
    # ------------------------------------------------------------------ #
    def on_open(self, conn):
        """Called after a client is accepted and its handler thread started.

        Args:
            conn: the accepted client socket.
        """
        pass

    def on_message(self, conn_headers, conn, message):
        """Handle one received text message.

        Args:
            conn_headers: header dict returned by handshake for this client.
            conn: the client socket.
            message: the received message text.

        Returns:
            A truthy value to keep reading; a falsy value ends the session.
        """
        return True

    def on_close(self):
        """Called after a handshaken client's connection has been closed."""
        self.log("INFO", "close this connect!")

    def on_error(self, error):
        """Called with an exception raised while serving a client.

        Args:
            error: the exception instance.
        """
        self.log("ERROR", error)

    def log(self, level, msg):
        """Print ``msg`` with a level prefix. Override to use another logging framework.

        Args:
            level: one of "INFO", "ERROR", "WARN" or "DEBUG". Unknown levels
                fall back to INFO.
            msg: the message; any object, converted with str().
        """
        msg = str(msg)
        if "INFO" == level:
            print('[INFO]:' + msg)
        elif "ERROR" == level:
            print('[ERROR]:' + msg)
        elif "WARN" == level:
            print('[WARN]:' + msg)
        elif "DEBUG" == level:
            print('[DEBUG]:' + msg)
        else:
            print("set level error, and reset on INFO!")
            print('[INFO]:' + msg)
if __name__ == '__main__':
    class TestWebsocketServer(WebsocketServerBase):
        """Demo subclass used only when this file is run directly; not part of the package."""

        def on_message(self, conn_headers, conn, message):
            """Echo each message back; "close" ends the session; /getData gets an extra reply."""
            self.log("INFO", "recv form client: " + message)
            wants_close = "close" in message
            if wants_close:
                self.close(conn)
                return False
            self.send_data(conn, message)
            is_get_data = conn_headers['Api'] == "/getData"
            if is_get_data:
                self.send_data(conn, "this is getData api")
            return True

    demo_server = TestWebsocketServer()
    demo_server.register_api("/getData")
    demo_server.ws_service("0.0.0.0", 9997)
0x02 测试
0x03 问题
- 目前在接收大量数据时会出现连接断开的情况,有待进一步优化。